diff options
-rw-r--r-- | notebooks/testing.ipynb | 53 | ||||
-rw-r--r-- | poetry.lock | 2 | ||||
-rw-r--r-- | rag/cli.py | 1 | ||||
-rw-r--r-- | rag/generator/__init__.py | 15 | ||||
-rw-r--r-- | rag/generator/abstract.py | 19 | ||||
-rw-r--r-- | rag/generator/cohere.py (renamed from rag/llm/cohere_generator.py) | 10 | ||||
-rw-r--r-- | rag/generator/ollama.py (renamed from rag/llm/ollama_generator.py) | 21 | ||||
-rw-r--r-- | rag/generator/prompt.py | 14 | ||||
-rw-r--r-- | rag/parser/__init__.py | 0 | ||||
-rw-r--r-- | rag/rag.py | 63 | ||||
-rw-r--r-- | rag/retriever/__init__.py (renamed from rag/db/__init__.py) | 0 | ||||
-rw-r--r-- | rag/retriever/document.py (renamed from rag/db/document.py) | 0 | ||||
-rw-r--r-- | rag/retriever/encoder.py (renamed from rag/llm/encoder.py) | 6 | ||||
-rw-r--r-- | rag/retriever/parser/__init__.py (renamed from rag/llm/__init__.py) | 0 | ||||
-rw-r--r-- | rag/retriever/parser/pdf.py (renamed from rag/parser/pdf.py) | 8 | ||||
-rw-r--r-- | rag/retriever/retriever.py | 37 | ||||
-rw-r--r-- | rag/retriever/vector.py (renamed from rag/db/vector.py) | 0 | ||||
-rw-r--r-- | rag/ui.py | 21 |
18 files changed, 193 insertions, 77 deletions
diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb index 18827fc..bcaa613 100644 --- a/notebooks/testing.ipynb +++ b/notebooks/testing.ipynb @@ -98,10 +98,61 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "05f7068c-b4c6-47b2-ac62-79c021838500", "metadata": {}, "outputs": [], + "source": [ + "from rag.generator import get_generator" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "31dd3e7a-1d6c-4cc6-8ffe-f83478f95875", + "metadata": {}, + "outputs": [], + "source": [ + "x = get_generator(\"ollama\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "749c58c4-5250-464c-b4f4-b5b4a00be3b3", + "metadata": {}, + "outputs": [], + "source": [ + "y = get_generator(\"ollama\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c59ece14-11e0-416c-afad-bc2ff6a526ec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x is y" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e664d3d-a787-45e5-8b71-8e5e0f348443", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/poetry.lock b/poetry.lock index ffeea46..8a2b3a2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -388,7 +388,7 @@ files = [ [[package]] name = "cffi" version = "1.16.0" -description = "Foreign Function Interface for Python calling C code." +description = "Foreign Function AbstractGenerator for Python calling C code." optional = false python-versions = ">=3.8" files = [ @@ -25,4 +25,3 @@ if __name__ == "__main__": print(result.answer + "\n") case _: print("Invalid option!") - diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py new file mode 100644 index 0000000..7da603c --- /dev/null +++ b/rag/generator/__init__.py @@ -0,0 +1,15 @@ +from typing import Type + +from .abstract import AbstractGenerator +from .ollama import Ollama +from .cohere import Cohere + + +def get_generator(model: str) -> Type[AbstractGenerator]: + match model: + case "ollama": + return Ollama() + case "cohere": + return Cohere() + case _: + exit(1) diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py new file mode 100644 index 0000000..a53b5d8 --- /dev/null +++ b/rag/generator/abstract.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from typing import Any, Generator + +from .prompt import Prompt + + +class AbstractGenerator(ABC, type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + @abstractmethod + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: + pass diff --git a/rag/llm/cohere_generator.py b/rag/generator/cohere.py index a6feacd..cf95c18 100644 --- a/rag/llm/cohere_generator.py +++ b/rag/generator/cohere.py @@ -3,14 +3,14 @@ from typing import Any, Generator import cohere from dataclasses import asdict -try: - from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: - from llm.ollama_generator import Prompt + +from .prompt import Prompt +from .abstract import AbstractGenerator + from loguru import logger as log -class CohereGenerator: +class Cohere(metaclass=AbstractGenerator): def __init__(self) -> None: self.client = cohere.Client(os.environ["COHERE_API_KEY"]) diff --git a/rag/llm/ollama_generator.py b/rag/generator/ollama.py index dd17f8d..ec6a55f 100644 --- a/rag/llm/ollama_generator.py +++ b/rag/generator/ollama.py @@ -1,21 +1,16 @@ import os -from dataclasses import dataclass from typing import Any, Generator, List import ollama from loguru import logger as log +from .prompt import Prompt +from .abstract import AbstractGenerator + try: - from rag.db.vector import Document + from rag.retriever.vector import Document except ModuleNotFoundError: - from db.vector import Document - - -@dataclass -class Prompt: - query: str - documents: List[Document] - + from retriever.vector import Document SYSTEM_PROMPT = ( "# System Preamble" @@ -28,7 +23,7 @@ SYSTEM_PROMPT = ( ) -class OllamaGenerator: +class Ollama(metaclass=AbstractGenerator): def __init__(self) -> None: self.model = os.environ["GENERATOR_MODEL"] @@ -72,5 +67,5 @@ class OllamaGenerator: metaprompt = self.__metaprompt(prompt) for chunk in ollama.generate( model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True - ): - yield chunk + ): + yield chunk["response"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py new file mode 100644 index 0000000..ed372c9 --- /dev/null +++ b/rag/generator/prompt.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import List + + +try: + from rag.retriever.vector import Document +except ModuleNotFoundError: + from retriever.vector import Document + + +@dataclass +class Prompt: + query: str + documents: List[Document] diff --git a/rag/parser/__init__.py b/rag/parser/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/rag/parser/__init__.py +++ /dev/null @@ -1,27 +1,22 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import List +from typing import List, Optional, Type from dotenv import load_dotenv from loguru import logger as log - try: - from rag.db.vector import VectorDB, Document - from rag.db.document import DocumentDB - from rag.llm.encoder import Encoder - from rag.llm.ollama_generator import OllamaGenerator, Prompt - from rag.llm.cohere_generator import CohereGenerator - from rag.parser.pdf import PDFParser + from rag.retriever.vector import Document + from rag.generator.abstract import AbstractGenerator + from rag.retriever.retriever import Retriever + from rag.generator.prompt import Prompt except ModuleNotFoundError: - from db.vector import VectorDB, Document - from db.document import DocumentDB - from llm.encoder import Encoder - from llm.ollama_generator import OllamaGenerator, Prompt - from llm.cohere_generator import CohereGenerator - from parser.pdf import PDFParser + from retriever.vector import Document + from generator.abstract import AbstractGenerator + from retriever.retriever import Retriever + from generator.prompt import Prompt @dataclass @@ -35,29 +30,23 @@ class RAG: def __init__(self) -> None: # FIXME: load this somewhere else? load_dotenv() - self.pdf_parser = PDFParser() - self.generator = CohereGenerator() - self.encoder = Encoder() - self.vector_db = VectorDB() - self.doc_db = DocumentDB() - - def add_pdf_from_path(self, path: Path): - blob = self.pdf_parser.from_path(path) - self.add_pdf_from_blob(blob) - - def add_pdf_from_blob(self, blob: BytesIO, source: str): - if self.doc_db.add(blob): - log.debug("Adding pdf to vector database...") - document = self.pdf_parser.from_data(blob) - chunks = self.pdf_parser.chunk(document, source) - points = self.encoder.encode_document(chunks) - self.vector_db.add(points) + self.retriever = Retriever() + + def add_pdf( + self, + path: Optional[Path], + blob: Optional[BytesIO], + source: Optional[str] = None, + ): + if path: + self.retriever.add_pdf_from_path(path) + elif blob: + self.retriever.add_pdf_from_blob(blob, source) else: - log.debug("Document already exists!") + log.error("Both path and blob was None, no pdf added!") - def search(self, query: str, limit: int = 5) -> List[Document]: - query_emb = self.encoder.encode_query(query) - return self.vector_db.search(query_emb, limit) + def retrieve(self, query: str, limit: int = 5) -> List[Document]: + return self.retriever.retrieve(query, limit) - def retrieve(self, prompt: Prompt): - yield from self.generator.generate(prompt) + def generate(self, generator: Type[AbstractGenerator], prompt: Prompt): + yield from generator.generate(prompt) diff --git a/rag/db/__init__.py b/rag/retriever/__init__.py index e69de29..e69de29 100644 --- a/rag/db/__init__.py +++ b/rag/retriever/__init__.py diff --git a/rag/db/document.py b/rag/retriever/document.py index 54ac451..54ac451 100644 --- a/rag/db/document.py +++ b/rag/retriever/document.py diff --git a/rag/llm/encoder.py b/rag/retriever/encoder.py index a59b1b4..753157f 100644 --- a/rag/llm/encoder.py +++ b/rag/retriever/encoder.py @@ -8,11 +8,7 @@ from langchain_core.documents import Document from loguru import logger as log from qdrant_client.http.models import StrictFloat - -try: - from rag.db.vector import Point -except ModuleNotFoundError: - from db.vector import Point +from .vector import Point class Encoder: diff --git a/rag/llm/__init__.py b/rag/retriever/parser/__init__.py index e69de29..e69de29 100644 --- a/rag/llm/__init__.py +++ b/rag/retriever/parser/__init__.py diff --git a/rag/parser/pdf.py b/rag/retriever/parser/pdf.py index ca9b72d..410f027 100644 --- a/rag/parser/pdf.py +++ b/rag/retriever/parser/pdf.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Iterator, List, Optional +from typing import List, Optional from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_core.documents import Document @@ -14,14 +14,14 @@ class PDFParser: def __init__(self) -> None: self.parser = PyPDFParser(password=None, extract_images=False) - def from_data(self, blob: Blob) -> Iterator[Document]: + def from_data(self, blob: Blob) -> List[Document]: return self.parser.parse(blob) - def from_path(self, path: Path) -> Iterator[Document]: + def from_path(self, path: Path) -> Blob: return Blob.from_path(path) def chunk( - self, document: Iterator[Document], source: Optional[str] = None + self, document: List[Document], source: Optional[str] = None ) -> List[Document]: splitter = RecursiveCharacterTextSplitter( chunk_size=int(os.environ["CHUNK_SIZE"]), diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py new file mode 100644 index 0000000..dbfdfa2 --- /dev/null +++ b/rag/retriever/retriever.py @@ -0,0 +1,37 @@ +from pathlib import Path +from typing import Optional, List +from loguru import logger as log + +from io import BytesIO +from .document import DocumentDB +from .encoder import Encoder +from .parser.pdf import PDFParser +from .vector import VectorDB, Document + + +class Retriever: + def __init__(self) -> None: + self.pdf_parser = PDFParser() + self.encoder = Encoder() + self.doc_db = DocumentDB() + self.vec_db = VectorDB() + + def add_pdf_from_path(self, path: Path): + log.debug(f"Adding pdf from {path}") + blob = self.pdf_parser.from_path(path) + self.add_pdf_from_blob(blob) + + def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): + if self.doc_db.add(blob): + log.debug("Adding pdf to vector database...") + document = self.pdf_parser.from_data(blob) + chunks = self.pdf_parser.chunk(document, source) + points = self.encoder.encode_document(chunks) + self.vec_db.add(points) + else: + log.debug("Document already exists!") + + def retrieve(self, query: str, limit: int = 5) -> List[Document]: + log.debug(f"Finding documents matching query: {query}") + query_emb = self.encoder.encode_query(query) + return self.vec_db.search(query_emb, limit) diff --git a/rag/db/vector.py b/rag/retriever/vector.py index fd2b2c2..fd2b2c2 100644 --- a/rag/db/vector.py +++ b/rag/retriever/vector.py @@ -1,16 +1,15 @@ import streamlit as st from langchain_community.document_loaders.blob_loaders import Blob +from .rag import RAG -try: - from rag.rag import RAG - from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: - from rag import RAG - from llm.ollama_generator import Prompt +from .generator import get_generator +from .generator.prompt import Prompt rag = RAG() +MODELS = ["ollama", "cohere"] + def upload_pdfs(): files = st.file_uploader( @@ -26,13 +25,15 @@ def upload_pdfs(): for file in files: source = file.name blob = Blob.from_data(file.read()) - rag.add_pdf_from_blob(blob, source) + rag.add_pdf(blob, source) if __name__ == "__main__": ss = st.session_state st.header("RAG-UI") + model = st.selectbox("Model", options=MODELS) + upload_pdfs() with st.form(key="query"): @@ -56,7 +57,7 @@ if __name__ == "__main__": query = ss.get("query", "") with st.spinner("Searching for documents..."): - documents = rag.search(query) + documents = rag.retrieve(query) prompt = Prompt(query, documents) @@ -69,6 +70,6 @@ if __name__ == "__main__": st.markdown("---") with result_column: + generator = get_generator(model) st.markdown("### Answer") - st.write_stream(rag.retrieve(prompt)) - + st.write_stream(rag.generate(generator, prompt)) |