diff options
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/cli.py | 1 | ||||
| -rw-r--r-- | rag/generator/__init__.py | 15 | ||||
| -rw-r--r-- | rag/generator/abstract.py | 19 | ||||
| -rw-r--r-- | rag/generator/cohere.py (renamed from rag/llm/cohere_generator.py) | 10 | ||||
| -rw-r--r-- | rag/generator/ollama.py (renamed from rag/llm/ollama_generator.py) | 21 | ||||
| -rw-r--r-- | rag/generator/prompt.py | 14 | ||||
| -rw-r--r-- | rag/parser/__init__.py | 0 | ||||
| -rw-r--r-- | rag/rag.py | 63 | ||||
| -rw-r--r-- | rag/retriever/__init__.py (renamed from rag/db/__init__.py) | 0 | ||||
| -rw-r--r-- | rag/retriever/document.py (renamed from rag/db/document.py) | 0 | ||||
| -rw-r--r-- | rag/retriever/encoder.py (renamed from rag/llm/encoder.py) | 6 | ||||
| -rw-r--r-- | rag/retriever/parser/__init__.py (renamed from rag/llm/__init__.py) | 0 | ||||
| -rw-r--r-- | rag/retriever/parser/pdf.py (renamed from rag/parser/pdf.py) | 8 | ||||
| -rw-r--r-- | rag/retriever/retriever.py | 37 | ||||
| -rw-r--r-- | rag/retriever/vector.py (renamed from rag/db/vector.py) | 0 | ||||
| -rw-r--r-- | rag/ui.py | 21 | 
16 files changed, 140 insertions, 75 deletions
@@ -25,4 +25,3 @@ if __name__ == "__main__":                      print(result.answer + "\n")              case _:                  print("Invalid option!") - diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py new file mode 100644 index 0000000..7da603c --- /dev/null +++ b/rag/generator/__init__.py @@ -0,0 +1,15 @@ +from typing import Type + +from .abstract import AbstractGenerator +from .ollama import Ollama +from .cohere import Cohere + + +def get_generator(model: str) -> Type[AbstractGenerator]: +    match model: +        case "ollama": +            return Ollama() +        case "cohere": +            return Cohere() +        case _: +            exit(1) diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py new file mode 100644 index 0000000..a53b5d8 --- /dev/null +++ b/rag/generator/abstract.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from typing import Any, Generator + +from .prompt import Prompt + + +class AbstractGenerator(ABC, type): +    _instances = {} + +    def __call__(cls, *args, **kwargs): +        if cls not in cls._instances: +            instance = super().__call__(*args, **kwargs) +            cls._instances[cls] = instance +        return cls._instances[cls] + +    @abstractmethod +    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: +        pass diff --git a/rag/llm/cohere_generator.py b/rag/generator/cohere.py index a6feacd..cf95c18 100644 --- a/rag/llm/cohere_generator.py +++ b/rag/generator/cohere.py @@ -3,14 +3,14 @@ from typing import Any, Generator  import cohere  from dataclasses import asdict -try: -    from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: -    from llm.ollama_generator import Prompt + +from .prompt import Prompt +from .abstract import AbstractGenerator +  from loguru import logger as log -class CohereGenerator: +class Cohere(metaclass=AbstractGenerator):      def __init__(self) -> None:          self.client = cohere.Client(os.environ["COHERE_API_KEY"]) diff --git a/rag/llm/ollama_generator.py b/rag/generator/ollama.py index dd17f8d..ec6a55f 100644 --- a/rag/llm/ollama_generator.py +++ b/rag/generator/ollama.py @@ -1,21 +1,16 @@  import os -from dataclasses import dataclass  from typing import Any, Generator, List  import ollama  from loguru import logger as log +from .prompt import Prompt +from .abstract import AbstractGenerator +  try: -    from rag.db.vector import Document +    from rag.retriever.vector import Document  except ModuleNotFoundError: -    from db.vector import Document - - -@dataclass -class Prompt: -    query: str -    documents: List[Document] - +    from retriever.vector import Document  SYSTEM_PROMPT = (      "# System Preamble" @@ -28,7 +23,7 @@ SYSTEM_PROMPT = (  ) -class OllamaGenerator: +class Ollama(metaclass=AbstractGenerator):      def __init__(self) -> None:          self.model = os.environ["GENERATOR_MODEL"] @@ -72,5 +67,5 @@ class OllamaGenerator:          metaprompt = self.__metaprompt(prompt)          for chunk in ollama.generate(              model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True -            ): -            yield chunk +        ): +            yield chunk["response"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py new file mode 100644 index 0000000..ed372c9 --- /dev/null +++ b/rag/generator/prompt.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import List + + +try: +    from rag.retriever.vector import Document +except ModuleNotFoundError: +    from retriever.vector import Document + + +@dataclass +class Prompt: +    query: str +    documents: List[Document] diff --git a/rag/parser/__init__.py b/rag/parser/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/rag/parser/__init__.py +++ /dev/null @@ -1,27 +1,22 @@  from dataclasses import dataclass  from io import BytesIO  from pathlib import Path -from typing import List +from typing import List, Optional, Type  from dotenv import load_dotenv  from loguru import logger as log -  try: -    from rag.db.vector import VectorDB, Document -    from rag.db.document import DocumentDB -    from rag.llm.encoder import Encoder -    from rag.llm.ollama_generator import OllamaGenerator, Prompt -    from rag.llm.cohere_generator import CohereGenerator -    from rag.parser.pdf import PDFParser +    from rag.retriever.vector import Document +    from rag.generator.abstract import AbstractGenerator +    from rag.retriever.retriever import Retriever +    from rag.generator.prompt import Prompt  except ModuleNotFoundError: -    from db.vector import VectorDB, Document -    from db.document import DocumentDB -    from llm.encoder import Encoder -    from llm.ollama_generator import OllamaGenerator, Prompt -    from llm.cohere_generator import CohereGenerator -    from parser.pdf import PDFParser +    from retriever.vector import Document +    from generator.abstract import AbstractGenerator +    from retriever.retriever import Retriever +    from generator.prompt import Prompt  @dataclass @@ -35,29 +30,23 @@ class RAG:      def __init__(self) -> None:          # FIXME: load this somewhere else?          load_dotenv() -        self.pdf_parser = PDFParser() -        self.generator = CohereGenerator() -        self.encoder = Encoder() -        self.vector_db = VectorDB() -        self.doc_db = DocumentDB() - -    def add_pdf_from_path(self, path: Path): -        blob = self.pdf_parser.from_path(path) -        self.add_pdf_from_blob(blob) - -    def add_pdf_from_blob(self, blob: BytesIO, source: str): -        if self.doc_db.add(blob): -            log.debug("Adding pdf to vector database...") -            document = self.pdf_parser.from_data(blob) -            chunks = self.pdf_parser.chunk(document, source) -            points = self.encoder.encode_document(chunks) -            self.vector_db.add(points) +        self.retriever = Retriever() + +    def add_pdf( +        self, +        path: Optional[Path], +        blob: Optional[BytesIO], +        source: Optional[str] = None, +    ): +        if path: +            self.retriever.add_pdf_from_path(path) +        elif blob: +            self.retriever.add_pdf_from_blob(blob, source)          else: -            log.debug("Document already exists!") +            log.error("Both path and blob was None, no pdf added!") -    def search(self, query: str, limit: int = 5) -> List[Document]: -        query_emb = self.encoder.encode_query(query) -        return self.vector_db.search(query_emb, limit) +    def retrieve(self, query: str, limit: int = 5) -> List[Document]: +        return self.retriever.retrieve(query, limit) -    def retrieve(self, prompt: Prompt): -        yield from self.generator.generate(prompt) +    def generate(self, generator: Type[AbstractGenerator], prompt: Prompt): +        yield from generator.generate(prompt) diff --git a/rag/db/__init__.py b/rag/retriever/__init__.py index e69de29..e69de29 100644 --- a/rag/db/__init__.py +++ b/rag/retriever/__init__.py diff --git a/rag/db/document.py b/rag/retriever/document.py index 54ac451..54ac451 100644 --- a/rag/db/document.py +++ b/rag/retriever/document.py diff --git a/rag/llm/encoder.py b/rag/retriever/encoder.py index a59b1b4..753157f 100644 --- a/rag/llm/encoder.py +++ b/rag/retriever/encoder.py @@ -8,11 +8,7 @@ from langchain_core.documents import Document  from loguru import logger as log  from qdrant_client.http.models import StrictFloat - -try: -    from rag.db.vector import Point -except ModuleNotFoundError: -    from db.vector import Point +from .vector import Point  class Encoder: diff --git a/rag/llm/__init__.py b/rag/retriever/parser/__init__.py index e69de29..e69de29 100644 --- a/rag/llm/__init__.py +++ b/rag/retriever/parser/__init__.py diff --git a/rag/parser/pdf.py b/rag/retriever/parser/pdf.py index ca9b72d..410f027 100644 --- a/rag/parser/pdf.py +++ b/rag/retriever/parser/pdf.py @@ -1,6 +1,6 @@  import os  from pathlib import Path -from typing import Iterator, List, Optional +from typing import List, Optional  from langchain.text_splitter import RecursiveCharacterTextSplitter  from langchain_core.documents import Document @@ -14,14 +14,14 @@ class PDFParser:      def __init__(self) -> None:          self.parser = PyPDFParser(password=None, extract_images=False) -    def from_data(self, blob: Blob) -> Iterator[Document]: +    def from_data(self, blob: Blob) -> List[Document]:          return self.parser.parse(blob) -    def from_path(self, path: Path) -> Iterator[Document]: +    def from_path(self, path: Path) -> Blob:          return Blob.from_path(path)      def chunk( -        self, document: Iterator[Document], source: Optional[str] = None +        self, document: List[Document], source: Optional[str] = None      ) -> List[Document]:          splitter = RecursiveCharacterTextSplitter(              chunk_size=int(os.environ["CHUNK_SIZE"]), diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py new file mode 100644 index 0000000..dbfdfa2 --- /dev/null +++ b/rag/retriever/retriever.py @@ -0,0 +1,37 @@ +from pathlib import Path +from typing import Optional, List +from loguru import logger as log + +from io import BytesIO +from .document import DocumentDB +from .encoder import Encoder +from .parser.pdf import PDFParser +from .vector import VectorDB, Document + + +class Retriever: +    def __init__(self) -> None: +        self.pdf_parser = PDFParser() +        self.encoder = Encoder() +        self.doc_db = DocumentDB() +        self.vec_db = VectorDB() + +    def add_pdf_from_path(self, path: Path): +        log.debug(f"Adding pdf from {path}") +        blob = self.pdf_parser.from_path(path) +        self.add_pdf_from_blob(blob) + +    def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): +        if self.doc_db.add(blob): +            log.debug("Adding pdf to vector database...") +            document = self.pdf_parser.from_data(blob) +            chunks = self.pdf_parser.chunk(document, source) +            points = self.encoder.encode_document(chunks) +            self.vec_db.add(points) +        else: +            log.debug("Document already exists!") + +    def retrieve(self, query: str, limit: int = 5) -> List[Document]: +        log.debug(f"Finding documents matching query: {query}") +        query_emb = self.encoder.encode_query(query) +        return self.vec_db.search(query_emb, limit) diff --git a/rag/db/vector.py b/rag/retriever/vector.py index fd2b2c2..fd2b2c2 100644 --- a/rag/db/vector.py +++ b/rag/retriever/vector.py @@ -1,16 +1,15 @@  import streamlit as st  from langchain_community.document_loaders.blob_loaders import Blob +from .rag import RAG -try: -    from rag.rag import RAG -    from rag.llm.ollama_generator import Prompt -except ModuleNotFoundError: -    from rag import RAG -    from llm.ollama_generator import Prompt +from .generator import get_generator +from .generator.prompt import Prompt  rag = RAG() +MODELS = ["ollama", "cohere"] +  def upload_pdfs():      files = st.file_uploader( @@ -26,13 +25,15 @@ def upload_pdfs():          for file in files:              source = file.name              blob = Blob.from_data(file.read()) -            rag.add_pdf_from_blob(blob, source) +            rag.add_pdf(blob, source)  if __name__ == "__main__":      ss = st.session_state      st.header("RAG-UI") +    model = st.selectbox("Model", options=MODELS) +      upload_pdfs()      with st.form(key="query"): @@ -56,7 +57,7 @@ if __name__ == "__main__":          query = ss.get("query", "")          with st.spinner("Searching for documents..."): -            documents = rag.search(query) +            documents = rag.retrieve(query)          prompt = Prompt(query, documents) @@ -69,6 +70,6 @@ if __name__ == "__main__":                  st.markdown("---")          with result_column: +            generator = get_generator(model)              st.markdown("### Answer") -            st.write_stream(rag.retrieve(prompt)) - +            st.write_stream(rag.generate(generator, prompt))  |