From 91ddb3672e514fa9824609ff047d7cab0c65631a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 9 Apr 2024 00:14:00 +0200 Subject: Refactor --- notebooks/testing.ipynb | 53 +++++++++++++++++++++++++++- poetry.lock | 2 +- rag/cli.py | 1 - rag/db/__init__.py | 0 rag/db/document.py | 57 ------------------------------ rag/db/vector.py | 73 -------------------------------------- rag/generator/__init__.py | 15 ++++++++ rag/generator/abstract.py | 19 ++++++++++ rag/generator/cohere.py | 29 +++++++++++++++ rag/generator/ollama.py | 71 +++++++++++++++++++++++++++++++++++++ rag/generator/prompt.py | 14 ++++++++ rag/llm/__init__.py | 0 rag/llm/cohere_generator.py | 29 --------------- rag/llm/encoder.py | 47 ------------------------- rag/llm/ollama_generator.py | 76 ---------------------------------------- rag/parser/__init__.py | 0 rag/parser/pdf.py | 34 ------------------ rag/rag.py | 63 ++++++++++++++------------------- rag/retriever/__init__.py | 0 rag/retriever/document.py | 57 ++++++++++++++++++++++++++++++ rag/retriever/encoder.py | 43 +++++++++++++++++++++++ rag/retriever/parser/__init__.py | 0 rag/retriever/parser/pdf.py | 34 ++++++++++++++++++ rag/retriever/retriever.py | 37 +++++++++++++++++++ rag/retriever/vector.py | 73 ++++++++++++++++++++++++++++++++++++++ rag/ui.py | 21 +++++------ 26 files changed, 482 insertions(+), 366 deletions(-) delete mode 100644 rag/db/__init__.py delete mode 100644 rag/db/document.py delete mode 100644 rag/db/vector.py create mode 100644 rag/generator/__init__.py create mode 100644 rag/generator/abstract.py create mode 100644 rag/generator/cohere.py create mode 100644 rag/generator/ollama.py create mode 100644 rag/generator/prompt.py delete mode 100644 rag/llm/__init__.py delete mode 100644 rag/llm/cohere_generator.py delete mode 100644 rag/llm/encoder.py delete mode 100644 rag/llm/ollama_generator.py delete mode 100644 rag/parser/__init__.py delete mode 100644 rag/parser/pdf.py create mode 100644 rag/retriever/__init__.py create mode 100644 rag/retriever/document.py create mode 100644 rag/retriever/encoder.py create mode 100644 rag/retriever/parser/__init__.py create mode 100644 rag/retriever/parser/pdf.py create mode 100644 rag/retriever/retriever.py create mode 100644 rag/retriever/vector.py diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb index 18827fc..bcaa613 100644 --- a/notebooks/testing.ipynb +++ b/notebooks/testing.ipynb @@ -98,10 +98,61 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "05f7068c-b4c6-47b2-ac62-79c021838500", "metadata": {}, "outputs": [], + "source": [ + "from rag.generator import get_generator" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "31dd3e7a-1d6c-4cc6-8ffe-f83478f95875", + "metadata": {}, + "outputs": [], + "source": [ + "x = get_generator(\"ollama\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "749c58c4-5250-464c-b4f4-b5b4a00be3b3", + "metadata": {}, + "outputs": [], + "source": [ + "y = get_generator(\"ollama\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c59ece14-11e0-416c-afad-bc2ff6a526ec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x is y" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e664d3d-a787-45e5-8b71-8e5e0f348443", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/poetry.lock b/poetry.lock index ffeea46..8a2b3a2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -388,7 +388,7 @@ files = [ [[package]] name = "cffi" version = "1.16.0" -description = "Foreign Function Interface for Python calling C code." +description = "Foreign Function AbstractGenerator for Python calling C code." optional = false python-versions = ">=3.8" files = [ diff --git a/rag/cli.py b/rag/cli.py index c470db3..d5651c1 100644 --- a/rag/cli.py +++ b/rag/cli.py @@ -25,4 +25,3 @@ if __name__ == "__main__": print(result.answer + "\n") case _: print("Invalid option!") - diff --git a/rag/db/__init__.py b/rag/db/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/rag/db/document.py b/rag/db/document.py deleted file mode 100644 index 54ac451..0000000 --- a/rag/db/document.py +++ /dev/null @@ -1,57 +0,0 @@ -import hashlib -import os - -from langchain_community.document_loaders.blob_loaders import Blob -import psycopg -from loguru import logger as log - -TABLES = """ -CREATE TABLE IF NOT EXISTS document ( - hash text PRIMARY KEY) -""" - - -class DocumentDB: - def __init__(self) -> None: - self.conn = psycopg.connect( - f"dbname={os.environ['DOCUMENT_DB_NAME']} user={os.environ['DOCUMENT_DB_USER']}" - ) - self.__configure() - - def close(self): - self.conn.close() - - def __configure(self): - log.debug("Creating documents table if it does not exist...") - with self.conn.cursor() as cur: - cur.execute(TABLES) - self.conn.commit() - - def __hash(self, blob: Blob) -> str: - log.debug("Hashing document...") - return hashlib.sha256(blob.as_bytes()).hexdigest() - - def add(self, blob: Blob) -> bool: - with self.conn.cursor() as cur: - hash = self.__hash(blob) - cur.execute( - """ - SELECT * FROM document - WHERE - hash = %s - """, - (hash,), - ) - exist = cur.fetchone() - if exist is None: - log.debug("Inserting document hash into documents db...") - cur.execute( - """ - INSERT INTO document - (hash) VALUES - (%s) - """, - (hash,), - ) - self.conn.commit() - return exist is None diff --git a/rag/db/vector.py b/rag/db/vector.py deleted file mode 100644 index fd2b2c2..0000000 --- a/rag/db/vector.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Dict, List - -from loguru import logger as log -from qdrant_client import QdrantClient -from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, PointStruct, VectorParams - - -@dataclass -class Point: - id: str - vector: List[StrictFloat] - payload: Dict[str, str] - - -@dataclass -class Document: - title: str - text: str - - -class VectorDB: - def __init__(self, score_threshold: float = 0.6): - self.dim = int(os.environ["EMBEDDING_DIM"]) - self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] - self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - self.score_threshold = score_threshold - self.__configure() - - def __configure(self): - collections = list( - map(lambda col: col.name, self.client.get_collections().collections) - ) - if self.collection_name not in collections: - log.debug(f"Configuring collection {self.collection_name}...") - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), - ) - else: - log.debug(f"Collection {self.collection_name} already exists...") - - def add(self, points: List[Point]): - log.debug(f"Inserting {len(points)} vectors into the vector db...") - self.client.upload_points( - collection_name=self.collection_name, - points=[ - PointStruct(id=point.id, vector=point.vector, payload=point.payload) - for point in points - ], - parallel=4, - max_retries=3, - ) - - def search(self, query: List[float], limit: int = 5) -> List[Document]: - log.debug("Searching for vectors...") - hits = self.client.search( - collection_name=self.collection_name, - query_vector=query, - limit=limit, - score_threshold=self.score_threshold, - ) - log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") - return list( - map( - lambda h: Document( - title=h.payload.get("source", ""), text=h.payload["text"] - ), - hits, - ) - ) diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py new file mode 100644 index 0000000..7da603c --- /dev/null +++ b/rag/generator/__init__.py @@ -0,0 +1,15 @@ +from typing import Type + +from .abstract import AbstractGenerator +from .ollama import Ollama +from .cohere import Cohere + + +def get_generator(model: str) -> Type[AbstractGenerator]: + match model: + case "ollama": + return Ollama() + case "cohere": + return Cohere() + case _: + exit(1) diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py new file mode 100644 index 0000000..a53b5d8 --- /dev/null +++ b/rag/generator/abstract.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from typing import Any, Generator + +from .prompt import Prompt + + +class AbstractGenerator(ABC, type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + @abstractmethod + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: + pass diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py new file mode 100644 index 0000000..cf95c18 --- /dev/null +++ b/rag/generator/cohere.py @@ -0,0 +1,29 @@ +import os +from typing import Any, Generator +import cohere + +from dataclasses import asdict + +from .prompt import Prompt +from .abstract import AbstractGenerator + +from loguru import logger as log + + +class Cohere(metaclass=AbstractGenerator): + def __init__(self) -> None: + self.client = cohere.Client(os.environ["COHERE_API_KEY"]) + + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: + log.debug("Generating answer from cohere") + for event in self.client.chat_stream( + message=prompt.query, + documents=[asdict(d) for d in prompt.documents], + prompt_truncation="AUTO", + ): + if event.event_type == "text-generation": + yield event.text + elif event.event_type == "citation-generation": + yield event.citations + elif event.event_type == "stream-end": + yield event.finish_reason diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py new file mode 100644 index 0000000..ec6a55f --- /dev/null +++ b/rag/generator/ollama.py @@ -0,0 +1,71 @@ +import os +from typing import Any, Generator, List + +import ollama +from loguru import logger as log + +from .prompt import Prompt +from .abstract import AbstractGenerator + +try: + from rag.retriever.vector import Document +except ModuleNotFoundError: + from retriever.vector import Document + +SYSTEM_PROMPT = ( + "# System Preamble" + "## Basic Rules" + "When you answer the user's requests, you cite your sources in your answers, according to those instructions." + "Answer the following question using the provided context.\n" + "## Style Guide" + "Unless the user asks for a different style of answer, you should answer " + "in full sentences, using proper grammar and spelling." +) + + +class Ollama(metaclass=AbstractGenerator): + def __init__(self) -> None: + self.model = os.environ["GENERATOR_MODEL"] + + def __context(self, documents: List[Document]) -> str: + results = [ + f"Document: {i}\ntitle: {doc.title}\n{doc.text}" + for i, doc in enumerate(documents) + ] + return "\n".join(results) + + def __metaprompt(self, prompt: Prompt) -> str: + # Include sources + metaprompt = ( + f'Question: "{prompt.query.strip()}"\n\n' + "Context:\n" + "\n" + f"{self.__context(prompt.documents)}\n\n" + "\n" + "Carefully perform the following instructions, in order, starting each " + "with a new line.\n" + "Firstly, Decide which of the retrieved documents are relevant to the " + "user's last input by writing 'Relevant Documents:' followed by " + "comma-separated list of document numbers.\n If none are relevant, you " + "should instead write 'None'.\n" + "Secondly, Decide which of the retrieved documents contain facts that " + "should be cited in a good answer to the user's last input by writing " + "'Cited Documents:' followed a comma-separated list of document numbers. " + "If you dont want to cite any of them, you should instead write 'None'.\n" + "Thirdly, Write 'Answer:' followed by a response to the user's last input " + "in high quality natural english. Use the retrieved documents to help you. " + "Do not insert any citations or grounding markup.\n" + "Finally, Write 'Grounded answer:' followed by a response to the user's " + "last input in high quality natural english. Use the symbols and " + " to indicate when a fact comes from a document in the search " + "result, e.g my fact for a fact from document 0." + ) + return metaprompt + + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: + log.debug("Generating answer...") + metaprompt = self.__metaprompt(prompt) + for chunk in ollama.generate( + model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True + ): + yield chunk["response"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py new file mode 100644 index 0000000..ed372c9 --- /dev/null +++ b/rag/generator/prompt.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import List + + +try: + from rag.retriever.vector import Document +except ModuleNotFoundError: + from retriever.vector import Document + + +@dataclass +class Prompt: + query: str + documents: List[Document] diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/rag/llm/cohere_generator.py b/rag/llm/cohere_generator.py deleted file mode 100644 index a6feacd..0000000 --- a/rag/llm/cohere_generator.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from typing import Any, Generator -import cohere - -from dataclasses import asdict -try: - from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: - from llm.ollama_generator import Prompt -from loguru import logger as log - - -class CohereGenerator: - def __init__(self) -> None: - self.client = cohere.Client(os.environ["COHERE_API_KEY"]) - - def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: - log.debug("Generating answer from cohere") - for event in self.client.chat_stream( - message=prompt.query, - documents=[asdict(d) for d in prompt.documents], - prompt_truncation="AUTO", - ): - if event.event_type == "text-generation": - yield event.text - elif event.event_type == "citation-generation": - yield event.citations - elif event.event_type == "stream-end": - yield event.finish_reason diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py deleted file mode 100644 index a59b1b4..0000000 --- a/rag/llm/encoder.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -from pathlib import Path -from typing import List, Dict -from uuid import uuid4 - -import ollama -from langchain_core.documents import Document -from loguru import logger as log -from qdrant_client.http.models import StrictFloat - - -try: - from rag.db.vector import Point -except ModuleNotFoundError: - from db.vector import Point - - -class Encoder: - def __init__(self) -> None: - self.model = os.environ["ENCODER_MODEL"] - self.query_prompt = "Represent this sentence for searching relevant passages: " - - def __encode(self, prompt: str) -> List[StrictFloat]: - return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) - - def __get_source(self, metadata: Dict[str, str]) -> str: - source = metadata["source"] - return Path(source).name - - def encode_document(self, chunks: List[Document]) -> List[Point]: - log.debug("Encoding document...") - return [ - Point( - id=uuid4().hex, - vector=self.__encode(chunk.page_content), - payload={ - "text": chunk.page_content, - "source": self.__get_source(chunk.metadata), - }, - ) - for chunk in chunks - ] - - def encode_query(self, query: str) -> List[StrictFloat]: - log.debug(f"Encoding query: {query}") - query = self.query_prompt + query - return self.__encode(query) diff --git a/rag/llm/ollama_generator.py b/rag/llm/ollama_generator.py deleted file mode 100644 index dd17f8d..0000000 --- a/rag/llm/ollama_generator.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Any, Generator, List - -import ollama -from loguru import logger as log - -try: - from rag.db.vector import Document -except ModuleNotFoundError: - from db.vector import Document - - -@dataclass -class Prompt: - query: str - documents: List[Document] - - -SYSTEM_PROMPT = ( - "# System Preamble" - "## Basic Rules" - "When you answer the user's requests, you cite your sources in your answers, according to those instructions." - "Answer the following question using the provided context.\n" - "## Style Guide" - "Unless the user asks for a different style of answer, you should answer " - "in full sentences, using proper grammar and spelling." -) - - -class OllamaGenerator: - def __init__(self) -> None: - self.model = os.environ["GENERATOR_MODEL"] - - def __context(self, documents: List[Document]) -> str: - results = [ - f"Document: {i}\ntitle: {doc.title}\n{doc.text}" - for i, doc in enumerate(documents) - ] - return "\n".join(results) - - def __metaprompt(self, prompt: Prompt) -> str: - # Include sources - metaprompt = ( - f'Question: "{prompt.query.strip()}"\n\n' - "Context:\n" - "\n" - f"{self.__context(prompt.documents)}\n\n" - "\n" - "Carefully perform the following instructions, in order, starting each " - "with a new line.\n" - "Firstly, Decide which of the retrieved documents are relevant to the " - "user's last input by writing 'Relevant Documents:' followed by " - "comma-separated list of document numbers.\n If none are relevant, you " - "should instead write 'None'.\n" - "Secondly, Decide which of the retrieved documents contain facts that " - "should be cited in a good answer to the user's last input by writing " - "'Cited Documents:' followed a comma-separated list of document numbers. " - "If you dont want to cite any of them, you should instead write 'None'.\n" - "Thirdly, Write 'Answer:' followed by a response to the user's last input " - "in high quality natural english. Use the retrieved documents to help you. " - "Do not insert any citations or grounding markup.\n" - "Finally, Write 'Grounded answer:' followed by a response to the user's " - "last input in high quality natural english. Use the symbols and " - " to indicate when a fact comes from a document in the search " - "result, e.g my fact for a fact from document 0." - ) - return metaprompt - - def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: - log.debug("Generating answer...") - metaprompt = self.__metaprompt(prompt) - for chunk in ollama.generate( - model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True - ): - yield chunk diff --git a/rag/parser/__init__.py b/rag/parser/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py deleted file mode 100644 index ca9b72d..0000000 --- a/rag/parser/pdf.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -from pathlib import Path -from typing import Iterator, List, Optional - -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_core.documents import Document -from langchain_community.document_loaders.parsers.pdf import ( - PyPDFParser, -) -from langchain_community.document_loaders.blob_loaders import Blob - - -class PDFParser: - def __init__(self) -> None: - self.parser = PyPDFParser(password=None, extract_images=False) - - def from_data(self, blob: Blob) -> Iterator[Document]: - return self.parser.parse(blob) - - def from_path(self, path: Path) -> Iterator[Document]: - return Blob.from_path(path) - - def chunk( - self, document: Iterator[Document], source: Optional[str] = None - ) -> List[Document]: - splitter = RecursiveCharacterTextSplitter( - chunk_size=int(os.environ["CHUNK_SIZE"]), - chunk_overlap=int(os.environ["CHUNK_OVERLAP"]), - ) - chunks = splitter.split_documents(document) - if source is not None: - for c in chunks: - c.metadata["source"] = source - return chunks diff --git a/rag/rag.py b/rag/rag.py index 93f9fd7..c95d93a 100644 --- a/rag/rag.py +++ b/rag/rag.py @@ -1,27 +1,22 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import List +from typing import List, Optional, Type from dotenv import load_dotenv from loguru import logger as log - try: - from rag.db.vector import VectorDB, Document - from rag.db.document import DocumentDB - from rag.llm.encoder import Encoder - from rag.llm.ollama_generator import OllamaGenerator, Prompt - from rag.llm.cohere_generator import CohereGenerator - from rag.parser.pdf import PDFParser + from rag.retriever.vector import Document + from rag.generator.abstract import AbstractGenerator + from rag.retriever.retriever import Retriever + from rag.generator.prompt import Prompt except ModuleNotFoundError: - from db.vector import VectorDB, Document - from db.document import DocumentDB - from llm.encoder import Encoder - from llm.ollama_generator import OllamaGenerator, Prompt - from llm.cohere_generator import CohereGenerator - from parser.pdf import PDFParser + from retriever.vector import Document + from generator.abstract import AbstractGenerator + from retriever.retriever import Retriever + from generator.prompt import Prompt @dataclass @@ -35,29 +30,23 @@ class RAG: def __init__(self) -> None: # FIXME: load this somewhere else? load_dotenv() - self.pdf_parser = PDFParser() - self.generator = CohereGenerator() - self.encoder = Encoder() - self.vector_db = VectorDB() - self.doc_db = DocumentDB() - - def add_pdf_from_path(self, path: Path): - blob = self.pdf_parser.from_path(path) - self.add_pdf_from_blob(blob) - - def add_pdf_from_blob(self, blob: BytesIO, source: str): - if self.doc_db.add(blob): - log.debug("Adding pdf to vector database...") - document = self.pdf_parser.from_data(blob) - chunks = self.pdf_parser.chunk(document, source) - points = self.encoder.encode_document(chunks) - self.vector_db.add(points) + self.retriever = Retriever() + + def add_pdf( + self, + path: Optional[Path], + blob: Optional[BytesIO], + source: Optional[str] = None, + ): + if path: + self.retriever.add_pdf_from_path(path) + elif blob: + self.retriever.add_pdf_from_blob(blob, source) else: - log.debug("Document already exists!") + log.error("Both path and blob was None, no pdf added!") - def search(self, query: str, limit: int = 5) -> List[Document]: - query_emb = self.encoder.encode_query(query) - return self.vector_db.search(query_emb, limit) + def retrieve(self, query: str, limit: int = 5) -> List[Document]: + return self.retriever.retrieve(query, limit) - def retrieve(self, prompt: Prompt): - yield from self.generator.generate(prompt) + def generate(self, generator: Type[AbstractGenerator], prompt: Prompt): + yield from generator.generate(prompt) diff --git a/rag/retriever/__init__.py b/rag/retriever/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rag/retriever/document.py b/rag/retriever/document.py new file mode 100644 index 0000000..54ac451 --- /dev/null +++ b/rag/retriever/document.py @@ -0,0 +1,57 @@ +import hashlib +import os + +from langchain_community.document_loaders.blob_loaders import Blob +import psycopg +from loguru import logger as log + +TABLES = """ +CREATE TABLE IF NOT EXISTS document ( + hash text PRIMARY KEY) +""" + + +class DocumentDB: + def __init__(self) -> None: + self.conn = psycopg.connect( + f"dbname={os.environ['DOCUMENT_DB_NAME']} user={os.environ['DOCUMENT_DB_USER']}" + ) + self.__configure() + + def close(self): + self.conn.close() + + def __configure(self): + log.debug("Creating documents table if it does not exist...") + with self.conn.cursor() as cur: + cur.execute(TABLES) + self.conn.commit() + + def __hash(self, blob: Blob) -> str: + log.debug("Hashing document...") + return hashlib.sha256(blob.as_bytes()).hexdigest() + + def add(self, blob: Blob) -> bool: + with self.conn.cursor() as cur: + hash = self.__hash(blob) + cur.execute( + """ + SELECT * FROM document + WHERE + hash = %s + """, + (hash,), + ) + exist = cur.fetchone() + if exist is None: + log.debug("Inserting document hash into documents db...") + cur.execute( + """ + INSERT INTO document + (hash) VALUES + (%s) + """, + (hash,), + ) + self.conn.commit() + return exist is None diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py new file mode 100644 index 0000000..753157f --- /dev/null +++ b/rag/retriever/encoder.py @@ -0,0 +1,43 @@ +import os +from pathlib import Path +from typing import List, Dict +from uuid import uuid4 + +import ollama +from langchain_core.documents import Document +from loguru import logger as log +from qdrant_client.http.models import StrictFloat + +from .vector import Point + + +class Encoder: + def __init__(self) -> None: + self.model = os.environ["ENCODER_MODEL"] + self.query_prompt = "Represent this sentence for searching relevant passages: " + + def __encode(self, prompt: str) -> List[StrictFloat]: + return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) + + def __get_source(self, metadata: Dict[str, str]) -> str: + source = metadata["source"] + return Path(source).name + + def encode_document(self, chunks: List[Document]) -> List[Point]: + log.debug("Encoding document...") + return [ + Point( + id=uuid4().hex, + vector=self.__encode(chunk.page_content), + payload={ + "text": chunk.page_content, + "source": self.__get_source(chunk.metadata), + }, + ) + for chunk in chunks + ] + + def encode_query(self, query: str) -> List[StrictFloat]: + log.debug(f"Encoding query: {query}") + query = self.query_prompt + query + return self.__encode(query) diff --git a/rag/retriever/parser/__init__.py b/rag/retriever/parser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py new file mode 100644 index 0000000..410f027 --- /dev/null +++ b/rag/retriever/parser/pdf.py @@ -0,0 +1,34 @@ +import os +from pathlib import Path +from typing import List, Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_core.documents import Document +from langchain_community.document_loaders.parsers.pdf import ( + PyPDFParser, +) +from langchain_community.document_loaders.blob_loaders import Blob + + +class PDFParser: + def __init__(self) -> None: + self.parser = PyPDFParser(password=None, extract_images=False) + + def from_data(self, blob: Blob) -> List[Document]: + return self.parser.parse(blob) + + def from_path(self, path: Path) -> Blob: + return Blob.from_path(path) + + def chunk( + self, document: List[Document], source: Optional[str] = None + ) -> List[Document]: + splitter = RecursiveCharacterTextSplitter( + chunk_size=int(os.environ["CHUNK_SIZE"]), + chunk_overlap=int(os.environ["CHUNK_OVERLAP"]), + ) + chunks = splitter.split_documents(document) + if source is not None: + for c in chunks: + c.metadata["source"] = source + return chunks diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py new file mode 100644 index 0000000..dbfdfa2 --- /dev/null +++ b/rag/retriever/retriever.py @@ -0,0 +1,37 @@ +from pathlib import Path +from typing import Optional, List +from loguru import logger as log + +from io import BytesIO +from .document import DocumentDB +from .encoder import Encoder +from .parser.pdf import PDFParser +from .vector import VectorDB, Document + + +class Retriever: + def __init__(self) -> None: + self.pdf_parser = PDFParser() + self.encoder = Encoder() + self.doc_db = DocumentDB() + self.vec_db = VectorDB() + + def add_pdf_from_path(self, path: Path): + log.debug(f"Adding pdf from {path}") + blob = self.pdf_parser.from_path(path) + self.add_pdf_from_blob(blob) + + def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): + if self.doc_db.add(blob): + log.debug("Adding pdf to vector database...") + document = self.pdf_parser.from_data(blob) + chunks = self.pdf_parser.chunk(document, source) + points = self.encoder.encode_document(chunks) + self.vec_db.add(points) + else: + log.debug("Document already exists!") + + def retrieve(self, query: str, limit: int = 5) -> List[Document]: + log.debug(f"Finding documents matching query: {query}") + query_emb = self.encoder.encode_query(query) + return self.vec_db.search(query_emb, limit) diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py new file mode 100644 index 0000000..fd2b2c2 --- /dev/null +++ b/rag/retriever/vector.py @@ -0,0 +1,73 @@ +import os +from dataclasses import dataclass +from typing import Dict, List + +from loguru import logger as log +from qdrant_client import QdrantClient +from qdrant_client.http.models import StrictFloat +from qdrant_client.models import Distance, PointStruct, VectorParams + + +@dataclass +class Point: + id: str + vector: List[StrictFloat] + payload: Dict[str, str] + + +@dataclass +class Document: + title: str + text: str + + +class VectorDB: + def __init__(self, score_threshold: float = 0.6): + self.dim = int(os.environ["EMBEDDING_DIM"]) + self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] + self.client = QdrantClient(url=os.environ["QDRANT_URL"]) + self.score_threshold = score_threshold + self.__configure() + + def __configure(self): + collections = list( + map(lambda col: col.name, self.client.get_collections().collections) + ) + if self.collection_name not in collections: + log.debug(f"Configuring collection {self.collection_name}...") + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), + ) + else: + log.debug(f"Collection {self.collection_name} already exists...") + + def add(self, points: List[Point]): + log.debug(f"Inserting {len(points)} vectors into the vector db...") + self.client.upload_points( + collection_name=self.collection_name, + points=[ + PointStruct(id=point.id, vector=point.vector, payload=point.payload) + for point in points + ], + parallel=4, + max_retries=3, + ) + + def search(self, query: List[float], limit: int = 5) -> List[Document]: + log.debug("Searching for vectors...") + hits = self.client.search( + collection_name=self.collection_name, + query_vector=query, + limit=limit, + score_threshold=self.score_threshold, + ) + log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") + return list( + map( + lambda h: Document( + title=h.payload.get("source", ""), text=h.payload["text"] + ), + hits, + ) + ) diff --git a/rag/ui.py b/rag/ui.py index 84dbbeb..83a22e2 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -1,16 +1,15 @@ import streamlit as st from langchain_community.document_loaders.blob_loaders import Blob +from .rag import RAG -try: - from rag.rag import RAG - from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: - from rag import RAG - from llm.ollama_generator import Prompt +from .generator import get_generator +from .generator.prompt import Prompt rag = RAG() +MODELS = ["ollama", "cohere"] + def upload_pdfs(): files = st.file_uploader( @@ -26,13 +25,15 @@ def upload_pdfs(): for file in files: source = file.name blob = Blob.from_data(file.read()) - rag.add_pdf_from_blob(blob, source) + rag.add_pdf(blob, source) if __name__ == "__main__": ss = st.session_state st.header("RAG-UI") + model = st.selectbox("Model", options=MODELS) + upload_pdfs() with st.form(key="query"): @@ -56,7 +57,7 @@ if __name__ == "__main__": query = ss.get("query", "") with st.spinner("Searching for documents..."): - documents = rag.search(query) + documents = rag.retrieve(query) prompt = Prompt(query, documents) @@ -69,6 +70,6 @@ if __name__ == "__main__": st.markdown("---") with result_column: + generator = get_generator(model) st.markdown("### Answer") - st.write_stream(rag.retrieve(prompt)) - + st.write_stream(rag.generate(generator, prompt)) -- cgit v1.2.3-70-g09d2