diff options
Diffstat (limited to 'rag/retriever/encoder.py')
-rw-r--r-- | rag/retriever/encoder.py | 42 |
1 files changed, 31 insertions, 11 deletions
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
index b68c3bb..8b02a14 100644
--- a/rag/retriever/encoder.py
+++ b/rag/retriever/encoder.py
@@ -1,7 +1,8 @@
+from dataclasses import dataclass
 import hashlib
 import os
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
 
 import ollama
 from langchain_core.documents import Document
@@ -9,29 +10,42 @@ from loguru import logger as log
 from qdrant_client.http.models import StrictFloat
 from tqdm import tqdm
 
-from .vector import Point
+from .vector import Documents, Point
+
+@dataclass
+class Query:
+    query: str
+
+
+Input = Query | Documents
 
 
 class Encoder:
     def __init__(self) -> None:
         self.model = os.environ["ENCODER_MODEL"]
-        self.query_prompt = "Represent this sentence for searching relevant passages: "
-
-    def __encode(self, prompt: str) -> List[StrictFloat]:
-        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+        self.preamble = (
+            "Represent this sentence for searching relevant passages: "
+            if "mxbai-embed-large" in self.model
+            else ""
+        )
 
     def __get_source(self, metadata: Dict[str, str]) -> str:
         source = metadata["source"]
         return Path(source).name
 
-    def encode_document(self, chunks: List[Document]) -> List[Point]:
+    def __encode(self, prompt: str) -> List[StrictFloat]:
+        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+    # TODO: move this to vec db and just return the embeddings
+    # TODO: use late chunking here
+    def __encode_document(self, chunks: List[Document]) -> List[Point]:
         log.debug("Encoding document...")
         return [
             Point(
                 id=hashlib.sha256(
                     chunk.page_content.encode(encoding="utf-8")
                 ).hexdigest(),
-                vector=self.__encode(chunk.page_content),
+                vector=list(self.__encode(chunk.page_content)),
                 payload={
                     "text": chunk.page_content,
                     "source": self.__get_source(chunk.metadata),
@@ -40,8 +54,14 @@ class Encoder:
             for chunk in tqdm(chunks)
         ]
 
-    def encode_query(self, query: str) -> List[StrictFloat]:
+    def __encode_query(self, query: str) -> List[StrictFloat]:
         log.debug(f"Encoding query: {query}")
-        if self.model == "mxbai-embed-large":
-            query = self.query_prompt + query
+        query = self.preamble + query
         return self.__encode(query)
+
+    def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]:
+        match x:
+            case Query(query):
+                return self.__encode_query(query)
+            case Documents(documents):
+                return self.__encode_document(documents)