diff options
Diffstat (limited to 'rag/retriever')
-rw-r--r-- | rag/retriever/document.py | 2 | ||||
-rw-r--r-- | rag/retriever/encoder.py | 42 | ||||
-rw-r--r-- | rag/retriever/parser/pdf.py | 6 | ||||
-rw-r--r-- | rag/retriever/rerank/local.py | 45 | ||||
-rw-r--r-- | rag/retriever/retriever.py | 57 | ||||
-rw-r--r-- | rag/retriever/vector.py | 32 |
6 files changed, 112 insertions, 72 deletions
diff --git a/rag/retriever/document.py b/rag/retriever/document.py index 132ec4b..df7a057 100644 --- a/rag/retriever/document.py +++ b/rag/retriever/document.py @@ -44,7 +44,7 @@ class DocumentDB: ) self.conn.commit() - def add(self, blob: Blob) -> bool: + def create(self, blob: Blob) -> bool: with self.conn.cursor() as cur: hash = self.__hash(blob) cur.execute( diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py index b68c3bb..8b02a14 100644 --- a/rag/retriever/encoder.py +++ b/rag/retriever/encoder.py @@ -1,7 +1,8 @@ +from dataclasses import dataclass import hashlib import os from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Union import ollama from langchain_core.documents import Document @@ -9,29 +10,42 @@ from loguru import logger as log from qdrant_client.http.models import StrictFloat from tqdm import tqdm -from .vector import Point +from .vector import Documents, Point + +@dataclass +class Query: + query: str + + +Input = Query | Documents class Encoder: def __init__(self) -> None: self.model = os.environ["ENCODER_MODEL"] - self.query_prompt = "Represent this sentence for searching relevant passages: " - - def __encode(self, prompt: str) -> List[StrictFloat]: - return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) + self.preamble = ( + "Represent this sentence for searching relevant passages: " + if "mxbai-embed-large" in self.model + else "" + ) def __get_source(self, metadata: Dict[str, str]) -> str: source = metadata["source"] return Path(source).name - def encode_document(self, chunks: List[Document]) -> List[Point]: + def __encode(self, prompt: str) -> List[StrictFloat]: + return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) + + # TODO: move this to vec db and just return the embeddings + # TODO: use late chunking here + def __encode_document(self, chunks: List[Document]) -> List[Point]: log.debug("Encoding document...") return [ Point( id=hashlib.sha256( 
chunk.page_content.encode(encoding="utf-8") ).hexdigest(), - vector=self.__encode(chunk.page_content), + vector=list(self.__encode(chunk.page_content)), payload={ "text": chunk.page_content, "source": self.__get_source(chunk.metadata), @@ -40,8 +54,14 @@ class Encoder: for chunk in tqdm(chunks) ] - def encode_query(self, query: str) -> List[StrictFloat]: + def __encode_query(self, query: str) -> List[StrictFloat]: log.debug(f"Encoding query: {query}") - if self.model == "mxbai-embed-large": - query = self.query_prompt + query + query = self.preamble + query return self.__encode(query) + + def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]: + match x: + case Query(query): + return self.__encode_query(query) + case Documents(documents): + return self.__encode_document(documents) diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py index 4c5addc..3253dc1 100644 --- a/rag/retriever/parser/pdf.py +++ b/rag/retriever/parser/pdf.py @@ -8,8 +8,10 @@ from langchain_community.document_loaders.parsers.pdf import ( PyPDFParser, ) from langchain_core.documents import Document +from rag.retriever.encoder import Chunks +# TODO: fix the PDFParser, remove langchain class PDFParser: def __init__(self) -> None: self.parser = PyPDFParser(password=None, extract_images=False) @@ -22,7 +24,7 @@ class PDFParser: def chunk( self, document: List[Document], source: Optional[str] = None - ) -> List[Document]: + ) -> Chunks: splitter = RecursiveCharacterTextSplitter( chunk_size=int(os.environ["CHUNK_SIZE"]), chunk_overlap=int(os.environ["CHUNK_OVERLAP"]), @@ -31,4 +33,4 @@ class PDFParser: if source is not None: for c in chunks: c.metadata["source"] = source - return chunks + return Chunks(chunks) diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index 231d50a..e2bef31 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -1,46 +1,39 @@ import os -from typing import List from loguru import logger as log 
from sentence_transformers import CrossEncoder -from rag.message import Message +from rag.message import Messages +from rag.retriever.encoder import Query from rag.retriever.rerank.abstract import AbstractReranker -from rag.retriever.vector import Document +from rag.retriever.vector import Documents + +Context = Documents | Messages class Reranker(metaclass=AbstractReranker): def __init__(self) -> None: self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu") self.top_k = int(os.environ["RERANK_TOP_K"]) - self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) - - def rerank_documents(self, query: str, documents: List[Document]) -> List[str]: - results = self.model.rank( - query=query, - documents=[d.text for d in documents], - return_documents=False, - top_k=self.top_k, - ) - ranking = list( - filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) - ) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(documents)}" - ) - return [documents[r.get("corpus_id", 0)] for r in ranking] + self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"]) - def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: + def rerank(self, query: Query, documents: Context) -> Context: results = self.model.rank( - query=query, - documents=[m.content for m in messages], + query=query.query, + documents=documents.content(), return_documents=False, top_k=self.top_k, ) - ranking = list( - filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results) + rankings = list( + map( + lambda x: x.get("corpus_id", 0), + filter( + lambda x: x.get("score", 0.0) > self.relevance_threshold, results + ), + ) ) log.debug( - f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}" + f"Reranking gave {len(rankings)} relevant documents of {len(documents)}" ) - return [messages[r.get("corpus_id", 0)] for r in ranking] + documents.rerank(rankings) + return documents diff 
--git a/rag/retriever/retriever.py b/rag/retriever/retriever.py index 351cfb0..7d43941 100644 --- a/rag/retriever/retriever.py +++ b/rag/retriever/retriever.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from io import BytesIO from pathlib import Path from typing import List, Optional @@ -5,11 +6,25 @@ from typing import List, Optional from loguru import logger as log from .document import DocumentDB -from .encoder import Encoder +from .encoder import Encoder, Query from .parser.pdf import PDFParser from .vector import Document, VectorDB +@dataclass +class FilePath: + path: Path + + +@dataclass +class Blob: + blob: BytesIO + source: Optional[str] = None + + +FileType = FilePath | Blob + + class Retriever: def __init__(self) -> None: self.pdf_parser = PDFParser() @@ -17,35 +32,29 @@ class Retriever: self.doc_db = DocumentDB() self.vec_db = VectorDB() - def __add_pdf_from_path(self, path: Path): - log.debug(f"Adding pdf from {path}") + def __index_pdf_from_path(self, path: Path): + log.debug(f"Indexing pdf from {path}") blob = self.pdf_parser.from_path(path) - self.__add_pdf_from_blob(blob) + self.__index_pdf_from_blob(blob, None) - def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): - if self.doc_db.add(blob): - log.debug("Adding pdf to vector database...") + def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]): + if self.doc_db.create(blob): + log.debug("Indexing pdf to vector database...") document = self.pdf_parser.from_data(blob) chunks = self.pdf_parser.chunk(document, source) - points = self.encoder.encode_document(chunks) - self.vec_db.add(points) + points = self.encoder.encode(chunks) + self.vec_db.index(points) else: log.debug("Document already exists!") - def add_pdf( - self, - path: Optional[Path] = None, - blob: Optional[BytesIO] = None, - source: Optional[str] = None, - ): - if path: - self.__add_pdf_from_path(path) - elif blob and source: - self.__add_pdf_from_blob(blob, source) - else: - 
log.error("Invalid input!") + def index(self, filetype: FileType): + match filetype: + case FilePath(path): + self.__index_pdf_from_path(path) + case Blob(blob, source): + self.__index_pdf_from_blob(blob, source) - def retrieve(self, query: str) -> List[Document]: - log.debug(f"Finding documents matching query: {query}") - query_emb = self.encoder.encode_query(query) + def search(self, query: Query) -> List[Document]: + log.debug(f"Finding documents matching query: {query.query}") + query_emb = self.encoder.encode(query) return self.vec_db.search(query_emb) diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py index 1a484f3..b36aee8 100644 --- a/rag/retriever/vector.py +++ b/rag/retriever/vector.py @@ -21,6 +21,20 @@ class Document: text: str +@dataclass +class Documents: + documents: List[Document] + + def __len__(self): + return len(self.documents) + + def content(self) -> List[str]: + return [d.text for d in self.documents] + + def rerank(self, rankings: List[int]): + self.documents = [self.documents[r] for r in rankings] + + class VectorDB: def __init__(self): self.dim = int(os.environ["EMBEDDING_DIM"]) @@ -47,7 +61,7 @@ class VectorDB: log.info(f"Deleting collection {self.collection_name}") self.client.delete_collection(self.collection_name) - def add(self, points: List[Point]): + def index(self, points: List[Point]): log.debug(f"Inserting {len(points)} vectors into the vector db...") self.client.upload_points( collection_name=self.collection_name, @@ -59,7 +73,7 @@ class VectorDB: max_retries=3, ) - def search(self, query: List[float]) -> List[Document]: + def search(self, query: List[float]) -> Documents: log.debug("Searching for vectors...") hits = self.client.search( collection_name=self.collection_name, @@ -68,11 +82,13 @@ class VectorDB: score_threshold=self.score_threshold, ) log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}") - return list( - map( - lambda h: Document( - title=h.payload.get("source", ""), 
text=h.payload["text"] - ), - hits, + return Documents( + list( + map( + lambda h: Document( + title=h.payload.get("source", ""), text=h.payload["text"] + ), + hits, + ) ) ) |