diff options
Diffstat (limited to 'rag')
-rw-r--r-- | rag/llm/encoder.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index d5e0566..095c6af 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -3,6 +3,7 @@ from typing import List import numpy as np import ollama +from langchain_core.documents import Document class Encoder: @@ -11,12 +12,12 @@ class Encoder: self.query_prompt = "Represent this sentence for searching relevant passages: " def __encode(self, prompt: str) -> np.ndarray: - x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt) + x = ollama.embeddings(model=self.model, prompt=prompt) x = np.array([x["embedding"]]).astype("float32") return x - def encode(self, doc: List[str]) -> List[np.ndarray]: - return [self.__encode(chunk) for chunk in doc] + def encode_document(self, chunks: List[Document]) -> np.ndarray: + return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks]) def query(self, query: str) -> np.ndarray: query = self.query_prompt + query |