summaryrefslogtreecommitdiff
path: root/rag/retriever
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever')
-rw-r--r--rag/retriever/document.py2
-rw-r--r--rag/retriever/encoder.py42
-rw-r--r--rag/retriever/parser/pdf.py6
-rw-r--r--rag/retriever/rerank/local.py45
-rw-r--r--rag/retriever/retriever.py57
-rw-r--r--rag/retriever/vector.py32
6 files changed, 112 insertions, 72 deletions
diff --git a/rag/retriever/document.py b/rag/retriever/document.py
index 132ec4b..df7a057 100644
--- a/rag/retriever/document.py
+++ b/rag/retriever/document.py
@@ -44,7 +44,7 @@ class DocumentDB:
)
self.conn.commit()
- def add(self, blob: Blob) -> bool:
+ def create(self, blob: Blob) -> bool:
with self.conn.cursor() as cur:
hash = self.__hash(blob)
cur.execute(
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
index b68c3bb..8b02a14 100644
--- a/rag/retriever/encoder.py
+++ b/rag/retriever/encoder.py
@@ -1,7 +1,8 @@
+from dataclasses import dataclass
import hashlib
import os
from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
import ollama
from langchain_core.documents import Document
@@ -9,29 +10,42 @@ from loguru import logger as log
from qdrant_client.http.models import StrictFloat
from tqdm import tqdm
-from .vector import Point
+from .vector import Documents, Point
+
+@dataclass
+class Query:
+ query: str
+
+
+Chunks = Documents
+Input = Query | Documents
class Encoder:
def __init__(self) -> None:
self.model = os.environ["ENCODER_MODEL"]
- self.query_prompt = "Represent this sentence for searching relevant passages: "
-
- def __encode(self, prompt: str) -> List[StrictFloat]:
- return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+ self.preamble = (
+ "Represent this sentence for searching relevant passages: "
+            if "mxbai-embed-large" in self.model
+ else ""
+ )
def __get_source(self, metadata: Dict[str, str]) -> str:
source = metadata["source"]
return Path(source).name
- def encode_document(self, chunks: List[Document]) -> List[Point]:
+ def __encode(self, prompt: str) -> List[StrictFloat]:
+ return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+ # TODO: move this to vec db and just return the embeddings
+ # TODO: use late chunking here
+ def __encode_document(self, chunks: List[Document]) -> List[Point]:
log.debug("Encoding document...")
return [
Point(
id=hashlib.sha256(
chunk.page_content.encode(encoding="utf-8")
).hexdigest(),
- vector=self.__encode(chunk.page_content),
+ vector=list(self.__encode(chunk.page_content)),
payload={
"text": chunk.page_content,
"source": self.__get_source(chunk.metadata),
@@ -40,8 +54,14 @@ class Encoder:
for chunk in tqdm(chunks)
]
- def encode_query(self, query: str) -> List[StrictFloat]:
+ def __encode_query(self, query: str) -> List[StrictFloat]:
log.debug(f"Encoding query: {query}")
- if self.model == "mxbai-embed-large":
- query = self.query_prompt + query
+ query = self.preamble + query
return self.__encode(query)
+
+ def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]:
+ match x:
+ case Query(query):
+ return self.__encode_query(query)
+ case Documents(documents):
+ return self.__encode_document(documents)
diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py
index 4c5addc..3253dc1 100644
--- a/rag/retriever/parser/pdf.py
+++ b/rag/retriever/parser/pdf.py
@@ -8,8 +8,10 @@ from langchain_community.document_loaders.parsers.pdf import (
PyPDFParser,
)
from langchain_core.documents import Document
+from rag.retriever.encoder import Chunks
+# TODO: fix the PDFParser, remove langchain
class PDFParser:
def __init__(self) -> None:
self.parser = PyPDFParser(password=None, extract_images=False)
@@ -22,7 +24,7 @@ class PDFParser:
def chunk(
self, document: List[Document], source: Optional[str] = None
- ) -> List[Document]:
+ ) -> Chunks:
splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ["CHUNK_SIZE"]),
chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
@@ -31,4 +33,4 @@ class PDFParser:
if source is not None:
for c in chunks:
c.metadata["source"] = source
- return chunks
+ return Chunks(chunks)
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 231d50a..e2bef31 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,46 +1,39 @@
import os
-from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
-from rag.message import Message
+from rag.message import Messages
+from rag.retriever.encoder import Query
from rag.retriever.rerank.abstract import AbstractReranker
-from rag.retriever.vector import Document
+from rag.retriever.vector import Documents
+
+Context = Documents | Messages
class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")
self.top_k = int(os.environ["RERANK_TOP_K"])
- self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
-
- def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
- results = self.model.rank(
- query=query,
- documents=[d.text for d in documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
- )
- return [documents[r.get("corpus_id", 0)] for r in ranking]
+ self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"])
- def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ def rerank(self, query: Query, documents: Context) -> Context:
results = self.model.rank(
- query=query,
- documents=[m.content for m in messages],
+ query=query.query,
+ documents=documents.content(),
return_documents=False,
top_k=self.top_k,
)
- ranking = list(
- filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ rankings = list(
+ map(
+ lambda x: x.get("corpus_id", 0),
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ ),
+ )
)
log.debug(
- f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ f"Reranking gave {len(rankings)} relevant documents of {len(documents)}"
)
- return [messages[r.get("corpus_id", 0)] for r in ranking]
+ documents.rerank(rankings)
+ return documents
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index 351cfb0..7d43941 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional
@@ -5,11 +6,26 @@ from typing import List, Optional
from loguru import logger as log
from .document import DocumentDB
-from .encoder import Encoder
+from .encoder import Encoder, Query
from .parser.pdf import PDFParser
-from .vector import Document, VectorDB
+from .vector import Document, Documents, VectorDB
+@dataclass
+class FilePath:
+ path: Path
+
+
+@dataclass
+class Blob:
+ blob: BytesIO
+ source: Optional[str] = None
+
+
+FileType = FilePath | Blob
+
+
class Retriever:
def __init__(self) -> None:
self.pdf_parser = PDFParser()
@@ -17,35 +32,29 @@ class Retriever:
self.doc_db = DocumentDB()
self.vec_db = VectorDB()
- def __add_pdf_from_path(self, path: Path):
- log.debug(f"Adding pdf from {path}")
+ def __index_pdf_from_path(self, path: Path):
+ log.debug(f"Indexing pdf from {path}")
blob = self.pdf_parser.from_path(path)
- self.__add_pdf_from_blob(blob)
+ self.__index_pdf_from_blob(blob, None)
- def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
- if self.doc_db.add(blob):
- log.debug("Adding pdf to vector database...")
+ def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]):
+ if self.doc_db.create(blob):
+ log.debug("Indexing pdf to vector database...")
document = self.pdf_parser.from_data(blob)
chunks = self.pdf_parser.chunk(document, source)
- points = self.encoder.encode_document(chunks)
- self.vec_db.add(points)
+ points = self.encoder.encode(chunks)
+ self.vec_db.index(points)
else:
log.debug("Document already exists!")
- def add_pdf(
- self,
- path: Optional[Path] = None,
- blob: Optional[BytesIO] = None,
- source: Optional[str] = None,
- ):
- if path:
- self.__add_pdf_from_path(path)
- elif blob and source:
- self.__add_pdf_from_blob(blob, source)
- else:
- log.error("Invalid input!")
+ def index(self, filetype: FileType):
+ match filetype:
+ case FilePath(path):
+ self.__index_pdf_from_path(path)
+ case Blob(blob, source):
+ self.__index_pdf_from_blob(blob, source)
- def retrieve(self, query: str) -> List[Document]:
- log.debug(f"Finding documents matching query: {query}")
- query_emb = self.encoder.encode_query(query)
+    def search(self, query: Query) -> "Documents":
+ log.debug(f"Finding documents matching query: {query.query}")
+ query_emb = self.encoder.encode(query)
return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index 1a484f3..b36aee8 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -21,6 +21,20 @@ class Document:
text: str
+@dataclass
+class Documents:
+ documents: List[Document]
+
+ def __len__(self):
+ return len(self.documents)
+
+ def content(self) -> List[str]:
+ return [d.text for d in self.documents]
+
+ def rerank(self, rankings: List[int]):
+ self.documents = [self.documents[r] for r in rankings]
+
+
class VectorDB:
def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
@@ -47,7 +61,7 @@ class VectorDB:
log.info(f"Deleting collection {self.collection_name}")
self.client.delete_collection(self.collection_name)
- def add(self, points: List[Point]):
+ def index(self, points: List[Point]):
log.debug(f"Inserting {len(points)} vectors into the vector db...")
self.client.upload_points(
collection_name=self.collection_name,
@@ -59,7 +73,7 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float]) -> List[Document]:
+ def search(self, query: List[float]) -> Documents:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
@@ -68,11 +82,13 @@ class VectorDB:
score_threshold=self.score_threshold,
)
log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
- return list(
- map(
- lambda h: Document(
- title=h.payload.get("source", ""), text=h.payload["text"]
- ),
- hits,
+ return Documents(
+ list(
+ map(
+ lambda h: Document(
+ title=h.payload.get("source", ""), text=h.payload["text"]
+ ),
+ hits,
+ )
)
)