summaryrefslogtreecommitdiff
path: root/rag/retriever/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever/encoder.py')
-rw-r--r--rag/retriever/encoder.py42
1 files changed, 31 insertions, 11 deletions
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
index b68c3bb..8b02a14 100644
--- a/rag/retriever/encoder.py
+++ b/rag/retriever/encoder.py
@@ -1,7 +1,8 @@
+from dataclasses import dataclass
import hashlib
import os
from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
import ollama
from langchain_core.documents import Document
@@ -9,29 +10,42 @@ from loguru import logger as log
from qdrant_client.http.models import StrictFloat
from tqdm import tqdm
-from .vector import Point
+from .vector import Documents, Point
+
+@dataclass
+class Query:
+ query: str
+
+
+Input = Query | Documents
class Encoder:
def __init__(self) -> None:
self.model = os.environ["ENCODER_MODEL"]
- self.query_prompt = "Represent this sentence for searching relevant passages: "
-
- def __encode(self, prompt: str) -> List[StrictFloat]:
- return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+ self.preamble = (
+ "Represent this sentence for searching relevant passages: "
+ if "mxbai-embed-large" in model_name
+ else ""
+ )
def __get_source(self, metadata: Dict[str, str]) -> str:
source = metadata["source"]
return Path(source).name
- def encode_document(self, chunks: List[Document]) -> List[Point]:
+ def __encode(self, prompt: str) -> List[StrictFloat]:
+ return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+ # TODO: move this to vec db and just return the embeddings
+ # TODO: use late chunking here
+ def __encode_document(self, chunks: List[Document]) -> List[Point]:
log.debug("Encoding document...")
return [
Point(
id=hashlib.sha256(
chunk.page_content.encode(encoding="utf-8")
).hexdigest(),
- vector=self.__encode(chunk.page_content),
+ vector=list(self.__encode(chunk.page_content)),
payload={
"text": chunk.page_content,
"source": self.__get_source(chunk.metadata),
@@ -40,8 +54,14 @@ class Encoder:
for chunk in tqdm(chunks)
]
- def encode_query(self, query: str) -> List[StrictFloat]:
+ def __encode_query(self, query: str) -> List[StrictFloat]:
log.debug(f"Encoding query: {query}")
- if self.model == "mxbai-embed-large":
- query = self.query_prompt + query
+ query = self.preamble + query
return self.__encode(query)
+
+ def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]:
+ match x:
+ case Query(query):
+ return self.__encode_query(query)
+ case Documents(documents):
+ return self.__encode_document(documents)