summaryrefslogtreecommitdiff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-05 02:04:20 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-05 02:04:20 +0200
commit7dfd91176fecfc442783260183ad8f34807cc284 (patch)
tree1bed1b39d37f524cf4a1a9b9d1583c620f07bbbe /rag/llm/encoder.py
parent551f08f61f111342ca2b48c5757f20fc9ef74542 (diff)
Update encoder
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r--rag/llm/encoder.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index d5e0566..095c6af 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -3,6 +3,7 @@ from typing import List
import numpy as np
import ollama
+from langchain_core.documents import Document
class Encoder:
@@ -11,12 +12,12 @@ class Encoder:
self.query_prompt = "Represent this sentence for searching relevant passages: "
def __encode(self, prompt: str) -> np.ndarray:
- x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt)
+ x = ollama.embeddings(model=self.model, prompt=prompt)
x = np.array([x["embedding"]]).astype("float32")
return x
- def encode(self, doc: List[str]) -> List[np.ndarray]:
- return [self.__encode(chunk) for chunk in doc]
+ def encode_document(self, chunks: List[Document]) -> np.ndarray:
+ return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks])
def query(self, query: str) -> np.ndarray:
query = self.query_prompt + query