diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-05 02:04:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-05 02:04:20 +0200 |
commit | 7dfd91176fecfc442783260183ad8f34807cc284 (patch) | |
tree | 1bed1b39d37f524cf4a1a9b9d1583c620f07bbbe /rag/llm/encoder.py | |
parent | 551f08f61f111342ca2b48c5757f20fc9ef74542 (diff) |
Update encoder
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r-- | rag/llm/encoder.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index d5e0566..095c6af 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -3,6 +3,7 @@ from typing import List import numpy as np import ollama +from langchain_core.documents import Document class Encoder: @@ -11,12 +12,12 @@ class Encoder: self.query_prompt = "Represent this sentence for searching relevant passages: " def __encode(self, prompt: str) -> np.ndarray: - x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt) + x = ollama.embeddings(model=self.model, prompt=prompt) x = np.array([x["embedding"]]).astype("float32") return x - def encode(self, doc: List[str]) -> List[np.ndarray]: - return [self.__encode(chunk) for chunk in doc] + def encode_document(self, chunks: List[Document]) -> np.ndarray: + return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks]) def query(self, query: str) -> np.ndarray: query = self.query_prompt + query |