From 3f447bff69c20109474c455f1ad52bd547ab49e9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 9 Apr 2024 00:41:55 +0200 Subject: Update --- rag/cli.py | 26 ++++++++++++++--------- rag/generator/__init__.py | 1 + rag/generator/abstract.py | 4 ++-- rag/rag.py | 52 ---------------------------------------------- rag/retriever/retriever.py | 19 ++++++++++++++--- rag/ui.py | 45 +++++++++++++++++---------------------- 6 files changed, 54 insertions(+), 93 deletions(-) delete mode 100644 rag/rag.py (limited to 'rag') diff --git a/rag/cli.py b/rag/cli.py index d5651c1..003718d 100644 --- a/rag/cli.py +++ b/rag/cli.py @@ -1,27 +1,33 @@ from pathlib import Path +from dotenv import load_dotenv +from rag.generator import get_generator, MODELS +from rag.retriever.retriever import Retriever +from rag.generator.prompt import Prompt -try: - from rag.rag import RAG -except ModuleNotFoundError: - from rag import RAG if __name__ == "__main__": - rag = RAG() + load_dotenv() + retriever = Retriever() + + print("\n\nRetrieval Augmented Generation\n") + model = input(f"Enter model ({MODELS}):") while True: - print("Retrieval Augmented Generation") - choice = input("1. add pdf from path\n2. Enter a query\n") + choice = input("1. Add pdf from path\n2. Enter a query\n") match choice: case "1": path = input("Enter the path to the pdf: ") path = Path(path) - rag.add_pdf_from_path(path) + retriever.add_pdf(path=path) case "2": query = input("Enter your query: ") if query: - result = rag.retrive(query) + generator = get_generator(model) + documents = retriever.retrieve(query) + prompt = Prompt(query, documents) print("Answer: \n") - print(result.answer + "\n") + for chunk in generator.generate(prompt): + print(chunk, end="", flush=True) case _: print("Invalid option!") diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py index 7da603c..541eff8 100644 --- a/rag/generator/__init__.py +++ b/rag/generator/__init__.py @@ -4,6 +4,7 @@ from .abstract import AbstractGenerator from .ollama import Ollama from .cohere import Cohere +MODELS = ["ollama", "cohere"] def get_generator(model: str) -> Type[AbstractGenerator]: match model: diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py index a53b5d8..5b336ea 100644 --- a/rag/generator/abstract.py +++ b/rag/generator/abstract.py @@ -1,11 +1,11 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Generator from .prompt import Prompt -class AbstractGenerator(ABC, type): +class AbstractGenerator(type): _instances = {} def __call__(cls, *args, **kwargs): diff --git a/rag/rag.py b/rag/rag.py deleted file mode 100644 index c95d93a..0000000 --- a/rag/rag.py +++ /dev/null @@ -1,52 +0,0 @@ -from dataclasses import dataclass -from io import BytesIO -from pathlib import Path -from typing import List, Optional, Type - -from dotenv import load_dotenv -from loguru import logger as log - - -try: - from rag.retriever.vector import Document - from rag.generator.abstract import AbstractGenerator - from rag.retriever.retriever import Retriever - from rag.generator.prompt import Prompt -except ModuleNotFoundError: - from retriever.vector import Document - from generator.abstract import AbstractGenerator - from retriever.retriever import Retriever - from generator.prompt import Prompt - - -@dataclass -class Response: - query: str - context: List[str] - answer: str - - -class RAG: - def __init__(self) -> None: - # FIXME: load this somewhere else? - load_dotenv() - self.retriever = Retriever() - - def add_pdf( - self, - path: Optional[Path], - blob: Optional[BytesIO], - source: Optional[str] = None, - ): - if path: - self.retriever.add_pdf_from_path(path) - elif blob: - self.retriever.add_pdf_from_blob(blob, source) - else: - log.error("Both path and blob was None, no pdf added!") - - def retrieve(self, query: str, limit: int = 5) -> List[Document]: - return self.retriever.retrieve(query, limit) - - def generate(self, generator: Type[AbstractGenerator], prompt: Prompt): - yield from generator.generate(prompt) diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py index dbfdfa2..885dafe 100644 --- a/rag/retriever/retriever.py +++ b/rag/retriever/retriever.py @@ -16,12 +16,12 @@ class Retriever: self.doc_db = DocumentDB() self.vec_db = VectorDB() - def add_pdf_from_path(self, path: Path): + def __add_pdf_from_path(self, path: Path): log.debug(f"Adding pdf from {path}") blob = self.pdf_parser.from_path(path) - self.add_pdf_from_blob(blob) + self.__add_pdf_from_blob(blob) - def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): + def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None): if self.doc_db.add(blob): log.debug("Adding pdf to vector database...") document = self.pdf_parser.from_data(blob) @@ -31,6 +31,19 @@ class Retriever: else: log.debug("Document already exists!") + def add_pdf( + self, + path: Optional[Path] = None, + blob: Optional[BytesIO] = None, + source: Optional[str] = None, + ): + if path: + self.__add_pdf_from_path(path) + elif blob and source: + self.__add_pdf_from_blob(blob, source) + else: + log.error("Invalid input!") + def retrieve(self, query: str, limit: int = 5) -> List[Document]: log.debug(f"Finding documents matching query: {query}") query_emb = self.encoder.encode_query(query) diff --git a/rag/ui.py b/rag/ui.py index 83a22e2..5083bb4 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -1,40 +1,33 @@ import streamlit as st from langchain_community.document_loaders.blob_loaders import Blob -from .rag import RAG -from .generator import get_generator -from .generator.prompt import Prompt +from dotenv import load_dotenv +from generator import get_generator, MODELS +from generator.prompt import Prompt +from retriever.retriever import Retriever -rag = RAG() -MODELS = ["ollama", "cohere"] +if __name__ == "__main__": + load_dotenv() + retriever = Retriever() + ss = st.session_state + st.header("Retrieval Augmented Generation") + model = st.selectbox("Model", options=MODELS) -def upload_pdfs(): files = st.file_uploader( "Choose pdfs to add to the knowledge base", type="pdf", accept_multiple_files=True, ) - if not files: - return - - with st.spinner("Indexing documents..."): - for file in files: - source = file.name - blob = Blob.from_data(file.read()) - rag.add_pdf(blob, source) - - -if __name__ == "__main__": - ss = st.session_state - st.header("RAG-UI") - - model = st.selectbox("Model", options=MODELS) - - upload_pdfs() + if files: + with st.spinner("Indexing documents..."): + for file in files: + source = file.name + blob = Blob.from_data(file.read()) + retriever.add_pdf(blob=blob, source=source) with st.form(key="query"): query = st.text_area( @@ -51,13 +44,13 @@ if __name__ == "__main__": (b,) = st.columns(1) (result_column, context_column) = st.columns(2) - if submit: + if submit and model: if not query: st.stop() query = ss.get("query", "") with st.spinner("Searching for documents..."): - documents = rag.retrieve(query) + documents = retriever.retrieve(query) prompt = Prompt(query, documents) @@ -72,4 +65,4 @@ if __name__ == "__main__": with result_column: generator = get_generator(model) st.markdown("### Answer") - st.write_stream(rag.generate(generator, prompt)) + st.write_stream(generator.generate(generator, prompt)) -- cgit v1.2.3-70-g09d2