| | | |
|---|---|---|
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2025-08-25 00:06:19 +0200 |
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2025-08-25 00:06:19 +0200 |
| commit | 28a1f5d4eddab6eb7c9ca77346c6fa9608856dd5 | (patch) |
| tree | 563ffd32f1a6f5705c1fbf6230d5d226fd0e0e48 | /rag/retriever |
| parent | 6afba9079eebe867ac4f1b6073b5277513e7491b | (diff) |
Diffstat (limited to 'rag/retriever')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | rag/retriever/document.py | 2 |
| -rw-r--r-- | rag/retriever/encoder.py | 42 |
| -rw-r--r-- | rag/retriever/parser/pdf.py | 6 |
| -rw-r--r-- | rag/retriever/rerank/local.py | 45 |
| -rw-r--r-- | rag/retriever/retriever.py | 57 |
| -rw-r--r-- | rag/retriever/vector.py | 32 |
6 files changed, 112 insertions, 72 deletions
diff --git a/rag/retriever/document.py b/rag/retriever/document.py
index 132ec4b..df7a057 100644
--- a/rag/retriever/document.py
+++ b/rag/retriever/document.py
@@ -44,7 +44,7 @@ class DocumentDB:
             )
             self.conn.commit()
 
-    def add(self, blob: Blob) -> bool:
+    def create(self, blob: Blob) -> bool:
         with self.conn.cursor() as cur:
             hash = self.__hash(blob)
             cur.execute(
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
index b68c3bb..8b02a14 100644
--- a/rag/retriever/encoder.py
+++ b/rag/retriever/encoder.py
@@ -1,7 +1,8 @@
+from dataclasses import dataclass
 import hashlib
 import os
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
 
 import ollama
 from langchain_core.documents import Document
@@ -9,29 +10,42 @@ from loguru import logger as log
 from qdrant_client.http.models import StrictFloat
 from tqdm import tqdm
 
-from .vector import Point
+from .vector import Documents, Point
+
+@dataclass
+class Query:
+    query: str
+
+
+Input = Query | Documents
 
 
 class Encoder:
     def __init__(self) -> None:
         self.model = os.environ["ENCODER_MODEL"]
-        self.query_prompt = "Represent this sentence for searching relevant passages: "
-
-    def __encode(self, prompt: str) -> List[StrictFloat]:
-        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+        self.preamble = (
+            "Represent this sentence for searching relevant passages: "
+            if "mxbai-embed-large" in model_name
+            else ""
+        )
 
     def __get_source(self, metadata: Dict[str, str]) -> str:
         source = metadata["source"]
         return Path(source).name
 
-    def encode_document(self, chunks: List[Document]) -> List[Point]:
+    def __encode(self, prompt: str) -> List[StrictFloat]:
+        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+    # TODO: move this to vec db and just return the embeddings
+    # TODO: use late chunking here
+    def __encode_document(self, chunks: List[Document]) -> List[Point]:
         log.debug("Encoding document...")
         return [
             Point(
                 id=hashlib.sha256(
                     chunk.page_content.encode(encoding="utf-8")
                 ).hexdigest(),
-                vector=self.__encode(chunk.page_content),
+                vector=list(self.__encode(chunk.page_content)),
                 payload={
                     "text": chunk.page_content,
                     "source": self.__get_source(chunk.metadata),
@@ -40,8 +54,14 @@ class Encoder:
             for chunk in tqdm(chunks)
         ]
 
-    def encode_query(self, query: str) -> List[StrictFloat]:
+    def __encode_query(self, query: str) -> List[StrictFloat]:
         log.debug(f"Encoding query: {query}")
-        if self.model == "mxbai-embed-large":
-            query = self.query_prompt + query
+        query = self.preamble + query
         return self.__encode(query)
+
+    def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]:
+        match x:
+            case Query(query):
+                return self.__encode_query(query)
+            case Documents(documents):
+                return self.__encode_document(documents)
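The encoder refactor above folds the public `encode_document`/`encode_query` pair into a single `encode` that dispatches on a `Query | Documents` union with structural pattern matching. A minimal, self-contained sketch of that dispatch is below; the stub `embed` stands in for `ollama.embeddings`, and `self.model` is assumed where the committed hunk references an undefined `model_name`.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Query:
    query: str


@dataclass
class Document:
    text: str


@dataclass
class Documents:
    documents: List[Document]


def embed(prompt: str) -> List[float]:
    # Stand-in for ollama.embeddings(model=..., prompt=...)["embedding"].
    return [float(len(prompt))]


class Encoder:
    def __init__(self, model: str = "mxbai-embed-large") -> None:
        self.model = model
        # The committed hunk tests `"mxbai-embed-large" in model_name`;
        # `self.model` is assumed here so the sketch runs.
        self.preamble = (
            "Represent this sentence for searching relevant passages: "
            if "mxbai-embed-large" in self.model
            else ""
        )

    def __encode_query(self, query: str) -> List[float]:
        return embed(self.preamble + query)

    def __encode_document(self, documents: List[Document]) -> List[List[float]]:
        return [embed(d.text) for d in documents]

    def encode(self, x: Union[Query, Documents]):
        # Dataclasses get __match_args__ for free, so positional patterns
        # bind the query string / document list directly.
        match x:
            case Query(query):
                return self.__encode_query(query)
            case Documents(documents):
                return self.__encode_document(documents)


encoder = Encoder()
print(encoder.encode(Query("what is late chunking?")))
print(encoder.encode(Documents([Document("chunk one"), Document("chunk two")])))
```

Keeping one public entry point means callers such as `Retriever` no longer pick the right private helper themselves; the union type documents the two accepted input shapes.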
diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py
index 4c5addc..3253dc1 100644
--- a/rag/retriever/parser/pdf.py
+++ b/rag/retriever/parser/pdf.py
@@ -8,8 +8,10 @@ from langchain_community.document_loaders.parsers.pdf import (
     PyPDFParser,
 )
 from langchain_core.documents import Document
+from rag.retriever.encoder import Chunks
 
 
+# TODO: fix the PDFParser, remove langchain
 class PDFParser:
     def __init__(self) -> None:
         self.parser = PyPDFParser(password=None, extract_images=False)
@@ -22,7 +24,7 @@ class PDFParser:
 
     def chunk(
         self, document: List[Document], source: Optional[str] = None
-    ) -> List[Document]:
+    ) -> Chunks:
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=int(os.environ["CHUNK_SIZE"]),
             chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
@@ -31,4 +33,4 @@ class PDFParser:
         if source is not None:
             for c in chunks:
                 c.metadata["source"] = source
-        return chunks
+        return Chunks(chunks)
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 231d50a..e2bef31 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,46 +1,39 @@
 import os
-from typing import List
 
 from loguru import logger as log
 from sentence_transformers import CrossEncoder
 
-from rag.message import Message
+from rag.message import Messages
+from rag.retriever.encoder import Query
 from rag.retriever.rerank.abstract import AbstractReranker
-from rag.retriever.vector import Document
+from rag.retriever.vector import Documents
+
+Context = Documents | Messages
 
 
 class Reranker(metaclass=AbstractReranker):
     def __init__(self) -> None:
         self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")
         self.top_k = int(os.environ["RERANK_TOP_K"])
-        self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
-
-    def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
-        results = self.model.rank(
-            query=query,
-            documents=[d.text for d in documents],
-            return_documents=False,
-            top_k=self.top_k,
-        )
-        ranking = list(
-            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
-        )
-        log.debug(
-            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
-        )
-        return [documents[r.get("corpus_id", 0)] for r in ranking]
+        self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"])
 
-    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+    def rerank(self, query: Query, documents: Context) -> Context:
         results = self.model.rank(
-            query=query,
-            documents=[m.content for m in messages],
+            query=query.query,
+            documents=documents.content(),
             return_documents=False,
             top_k=self.top_k,
         )
-        ranking = list(
-            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        rankings = list(
+            map(
+                lambda x: x.get("corpus_id", 0),
+                filter(
+                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+                ),
+            )
        )
         log.debug(
-            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+            f"Reranking gave {len(rankings)} relevant documents of {len(documents)}"
         )
-        return [messages[r.get("corpus_id", 0)] for r in ranking]
+        documents.rerank(rankings)
+        return documents
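The reranker now exposes a single `rerank(query, documents)` over the `Context = Documents | Messages` union: cross-encoder scores are filtered against the relevance threshold, mapped to their `corpus_id` indices, and the container reorders itself in place. A sketch of that flow, with a toy term-overlap scorer standing in for `sentence_transformers.CrossEncoder.rank` and a hard-coded threshold instead of `RERANK_RELEVANCE_THRESHOLD`:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Document:
    title: str
    text: str


@dataclass
class Documents:
    documents: List[Document]

    def __len__(self) -> int:
        return len(self.documents)

    def content(self) -> List[str]:
        return [d.text for d in self.documents]

    def rerank(self, rankings: List[int]) -> None:
        self.documents = [self.documents[r] for r in rankings]


def rank_stub(query: str, documents: List[str], top_k: int) -> List[Dict]:
    # Stand-in for CrossEncoder.rank(..., return_documents=False, top_k=top_k):
    # scores here are naive term overlap, highest first.
    scored = [
        {"corpus_id": i, "score": float(sum(w in doc.lower() for w in query.lower().split()))}
        for i, doc in enumerate(documents)
    ]
    return sorted(scored, key=lambda x: x["score"], reverse=True)[:top_k]


def rerank(query: str, documents: Documents, threshold: float = 0.5) -> Documents:
    results = rank_stub(query, documents.content(), top_k=5)
    # Keep only hits above the threshold, then reorder the container in place.
    rankings = [r["corpus_id"] for r in results if r["score"] > threshold]
    documents.rerank(rankings)
    return documents


docs = Documents(
    [
        Document("a.pdf", "Vector databases store embeddings."),
        Document("b.pdf", "Bananas are yellow."),
        Document("c.pdf", "Embeddings are dense vector representations."),
    ]
)
print([d.title for d in rerank("vector embeddings", docs).documents])  # ['a.pdf', 'c.pdf']
```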
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index 351cfb0..7d43941 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from io import BytesIO
 from pathlib import Path
 from typing import List, Optional
@@ -5,11 +6,25 @@ from typing import List, Optional
 from loguru import logger as log
 
 from .document import DocumentDB
-from .encoder import Encoder
+from .encoder import Encoder, Query
 from .parser.pdf import PDFParser
 from .vector import Document, VectorDB
 
+
+@dataclass
+class FilePath:
+    path: Path
+
+
+@dataclass
+class Blob:
+    blob: BytesIO
+    source: Optional[str] = None
+
+
+FileType = FilePath | Blob
+
 
 class Retriever:
     def __init__(self) -> None:
         self.pdf_parser = PDFParser()
@@ -17,35 +32,29 @@ class Retriever:
         self.doc_db = DocumentDB()
         self.vec_db = VectorDB()
 
-    def __add_pdf_from_path(self, path: Path):
-        log.debug(f"Adding pdf from {path}")
+    def __index_pdf_from_path(self, path: Path):
+        log.debug(f"Indexing pdf from {path}")
         blob = self.pdf_parser.from_path(path)
-        self.__add_pdf_from_blob(blob)
+        self.__index_pdf_from_blob(blob, None)
 
-    def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
-        if self.doc_db.add(blob):
-            log.debug("Adding pdf to vector database...")
+    def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]):
+        if self.doc_db.create(blob):
+            log.debug("Indexing pdf to vector database...")
             document = self.pdf_parser.from_data(blob)
             chunks = self.pdf_parser.chunk(document, source)
-            points = self.encoder.encode_document(chunks)
-            self.vec_db.add(points)
+            points = self.encoder.encode(chunks)
+            self.vec_db.index(points)
         else:
             log.debug("Document already exists!")
 
-    def add_pdf(
-        self,
-        path: Optional[Path] = None,
-        blob: Optional[BytesIO] = None,
-        source: Optional[str] = None,
-    ):
-        if path:
-            self.__add_pdf_from_path(path)
-        elif blob and source:
-            self.__add_pdf_from_blob(blob, source)
-        else:
-            log.error("Invalid input!")
+    def index(self, filetype: FileType):
+        match filetype:
+            case FilePath(path):
+                self.__index_pdf_from_path(path)
+            case Blob(blob, source):
+                self.__index_pdf_from_blob(blob, source)
 
-    def retrieve(self, query: str) -> List[Document]:
-        log.debug(f"Finding documents matching query: {query}")
-        query_emb = self.encoder.encode_query(query)
+    def search(self, query: Query) -> List[Document]:
+        log.debug(f"Finding documents matching query: {query.query}")
+        query_emb = self.encoder.encode(query)
         return self.vec_db.search(query_emb)
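With this change the public retriever surface becomes `index(FileType)` and `search(Query)` in place of `add_pdf(...)` and `retrieve(...)`. A hypothetical usage sketch, assuming the environment variables the constructors read (`ENCODER_MODEL`, `EMBEDDING_DIM`, `CHUNK_SIZE`, ...) and the backing document and vector databases are reachable; the file names are made up, and calling `content()` on the result relies on `VectorDB.search` returning the new `Documents` wrapper even though `search` is still annotated as `List[Document]`:

```python
from io import BytesIO
from pathlib import Path

from rag.retriever.encoder import Query
from rag.retriever.retriever import Blob, FilePath, Retriever

retriever = Retriever()

# Index a PDF from disk; the FilePath case of the match drives parsing and encoding.
retriever.index(FilePath(Path("papers/attention.pdf")))

# Or index an in-memory blob with an explicit source name (the Blob case).
pdf_bytes = Path("papers/attention.pdf").read_bytes()
retriever.index(Blob(blob=BytesIO(pdf_bytes), source="attention.pdf"))

# Search now takes a Query value instead of a bare string.
documents = retriever.search(Query("what is multi-head attention?"))
print(documents.content())
```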
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index 1a484f3..b36aee8 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -21,6 +21,20 @@ class Document:
     text: str
 
 
+@dataclass
+class Documents:
+    documents: List[Document]
+
+    def __len__(self):
+        return len(self.documents)
+
+    def content(self) -> List[str]:
+        return [d.text for d in self.documents]
+
+    def rerank(self, rankings: List[int]):
+        self.documents = [self.documents[r] for r in rankings]
+
+
 class VectorDB:
     def __init__(self):
         self.dim = int(os.environ["EMBEDDING_DIM"])
@@ -47,7 +61,7 @@ class VectorDB:
         log.info(f"Deleting collection {self.collection_name}")
         self.client.delete_collection(self.collection_name)
 
-    def add(self, points: List[Point]):
+    def index(self, points: List[Point]):
         log.debug(f"Inserting {len(points)} vectors into the vector db...")
         self.client.upload_points(
             collection_name=self.collection_name,
@@ -59,7 +73,7 @@
             max_retries=3,
         )
 
-    def search(self, query: List[float]) -> List[Document]:
+    def search(self, query: List[float]) -> Documents:
         log.debug("Searching for vectors...")
         hits = self.client.search(
             collection_name=self.collection_name,
@@ -68,11 +82,13 @@
             score_threshold=self.score_threshold,
         )
         log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
-        return list(
-            map(
-                lambda h: Document(
-                    title=h.payload.get("source", ""), text=h.payload["text"]
-                ),
-                hits,
+        return Documents(
+            list(
+                map(
+                    lambda h: Document(
+                        title=h.payload.get("source", ""), text=h.payload["text"]
+                    ),
+                    hits,
+                )
             )
         )
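`VectorDB.index` and `VectorDB.search` keep wrapping the same qdrant-client calls the diff shows (`upload_points` and `search`); only the method names and the `Documents` return type change. A runnable in-memory sketch of that index/search pair, with a made-up collection name, toy vectors, and payloads shaped like the ones the encoder stores:

```python
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, PointStruct, VectorParams

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="knowledge-base",
    vectors_config=VectorParams(size=3, distance=Distance.COSINE),
)

# index(): upload points whose payload keeps the chunk text and its source.
client.upload_points(
    collection_name="knowledge-base",
    points=[
        PointStruct(id=1, vector=[0.9, 0.1, 0.0], payload={"source": "a.pdf", "text": "chunk one"}),
        PointStruct(id=2, vector=[0.0, 0.2, 0.9], payload={"source": "b.pdf", "text": "chunk two"}),
    ],
    max_retries=3,
)

# search(): scored points come back with their payloads, which VectorDB.search
# above rewraps into Document values inside a Documents container.
hits = client.search(
    collection_name="knowledge-base",
    query_vector=[1.0, 0.0, 0.0],
    limit=5,
    score_threshold=0.1,
)
print([(h.payload["source"], h.payload["text"]) for h in hits])
```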