summary | refs | log | tree | commit | diff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r-- rag/llm/encoder.py 7
1 files changed, 4 insertions, 3 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index d5e0566..095c6af 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -3,6 +3,7 @@ from typing import List
import numpy as np
import ollama
+from langchain_core.documents import Document
class Encoder:
@@ -11,12 +12,12 @@ class Encoder:
self.query_prompt = "Represent this sentence for searching relevant passages: "
def __encode(self, prompt: str) -> np.ndarray:
- x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt)
+ x = ollama.embeddings(model=self.model, prompt=prompt)
x = np.array([x["embedding"]]).astype("float32")
return x
- def encode(self, doc: List[str]) -> List[np.ndarray]:
- return [self.__encode(chunk) for chunk in doc]
+ def encode_document(self, chunks: List[Document]) -> np.ndarray:
+ return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks])
def query(self, query: str) -> np.ndarray:
query = self.query_prompt + query