summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-09 00:41:55 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-09 00:41:55 +0200
commit: 3f447bff69c20109474c455f1ad52bd547ab49e9 (patch)
tree: 66695ca0e3423e2973c5b24ec1ae7096019b5dd0
parent: c05eae81f9aaa0a764203446ec54d7dd7cbeb66f (diff)
Update
-rw-r--r-- rag/cli.py 26
-rw-r--r-- rag/generator/__init__.py 1
-rw-r--r-- rag/generator/abstract.py 4
-rw-r--r-- rag/rag.py 52
-rw-r--r-- rag/retriever/retriever.py 19
-rw-r--r-- rag/ui.py 45
6 files changed, 54 insertions, 93 deletions
diff --git a/rag/cli.py b/rag/cli.py
index d5651c1..003718d 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -1,27 +1,33 @@
from pathlib import Path
+from dotenv import load_dotenv
+from rag.generator import get_generator, MODELS
+from rag.retriever.retriever import Retriever
+from rag.generator.prompt import Prompt
-try:
- from rag.rag import RAG
-except ModuleNotFoundError:
- from rag import RAG
if __name__ == "__main__":
- rag = RAG()
+ load_dotenv()
+ retriever = Retriever()
+
+ print("\n\nRetrieval Augmented Generation\n")
+ model = input(f"Enter model ({MODELS}):")
while True:
- print("Retrieval Augmented Generation")
- choice = input("1. add pdf from path\n2. Enter a query\n")
+ choice = input("1. Add pdf from path\n2. Enter a query\n")
match choice:
case "1":
path = input("Enter the path to the pdf: ")
path = Path(path)
- rag.add_pdf_from_path(path)
+ retriever.add_pdf(path=path)
case "2":
query = input("Enter your query: ")
if query:
- result = rag.retrive(query)
+ generator = get_generator(model)
+ documents = retriever.retrieve(query)
+ prompt = Prompt(query, documents)
print("Answer: \n")
- print(result.answer + "\n")
+ for chunk in generator.generate(prompt):
+ print(chunk, end="", flush=True)
case _:
print("Invalid option!")
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
index 7da603c..541eff8 100644
--- a/rag/generator/__init__.py
+++ b/rag/generator/__init__.py
@@ -4,6 +4,7 @@ from .abstract import AbstractGenerator
from .ollama import Ollama
from .cohere import Cohere
+MODELS = ["ollama", "cohere"]
def get_generator(model: str) -> Type[AbstractGenerator]:
match model:
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index a53b5d8..5b336ea 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,11 +1,11 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
from typing import Any, Generator
from .prompt import Prompt
-class AbstractGenerator(ABC, type):
+class AbstractGenerator(type):
_instances = {}
def __call__(cls, *args, **kwargs):
diff --git a/rag/rag.py b/rag/rag.py
deleted file mode 100644
index c95d93a..0000000
--- a/rag/rag.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from dataclasses import dataclass
-from io import BytesIO
-from pathlib import Path
-from typing import List, Optional, Type
-
-from dotenv import load_dotenv
-from loguru import logger as log
-
-
-try:
- from rag.retriever.vector import Document
- from rag.generator.abstract import AbstractGenerator
- from rag.retriever.retriever import Retriever
- from rag.generator.prompt import Prompt
-except ModuleNotFoundError:
- from retriever.vector import Document
- from generator.abstract import AbstractGenerator
- from retriever.retriever import Retriever
- from generator.prompt import Prompt
-
-
-@dataclass
-class Response:
- query: str
- context: List[str]
- answer: str
-
-
-class RAG:
- def __init__(self) -> None:
- # FIXME: load this somewhere else?
- load_dotenv()
- self.retriever = Retriever()
-
- def add_pdf(
- self,
- path: Optional[Path],
- blob: Optional[BytesIO],
- source: Optional[str] = None,
- ):
- if path:
- self.retriever.add_pdf_from_path(path)
- elif blob:
- self.retriever.add_pdf_from_blob(blob, source)
- else:
- log.error("Both path and blob was None, no pdf added!")
-
- def retrieve(self, query: str, limit: int = 5) -> List[Document]:
- return self.retriever.retrieve(query, limit)
-
- def generate(self, generator: Type[AbstractGenerator], prompt: Prompt):
- yield from generator.generate(prompt)
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index dbfdfa2..885dafe 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -16,12 +16,12 @@ class Retriever:
self.doc_db = DocumentDB()
self.vec_db = VectorDB()
- def add_pdf_from_path(self, path: Path):
+ def __add_pdf_from_path(self, path: Path):
log.debug(f"Adding pdf from {path}")
blob = self.pdf_parser.from_path(path)
- self.add_pdf_from_blob(blob)
+ self.__add_pdf_from_blob(blob)
- def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
+ def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
if self.doc_db.add(blob):
log.debug("Adding pdf to vector database...")
document = self.pdf_parser.from_data(blob)
@@ -31,6 +31,19 @@ class Retriever:
else:
log.debug("Document already exists!")
+ def add_pdf(
+ self,
+ path: Optional[Path] = None,
+ blob: Optional[BytesIO] = None,
+ source: Optional[str] = None,
+ ):
+ if path:
+ self.__add_pdf_from_path(path)
+ elif blob and source:
+ self.__add_pdf_from_blob(blob, source)
+ else:
+ log.error("Invalid input!")
+
def retrieve(self, query: str, limit: int = 5) -> List[Document]:
log.debug(f"Finding documents matching query: {query}")
query_emb = self.encoder.encode_query(query)
diff --git a/rag/ui.py b/rag/ui.py
index 83a22e2..5083bb4 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,40 +1,33 @@
import streamlit as st
from langchain_community.document_loaders.blob_loaders import Blob
-from .rag import RAG
-from .generator import get_generator
-from .generator.prompt import Prompt
+from dotenv import load_dotenv
+from generator import get_generator, MODELS
+from generator.prompt import Prompt
+from retriever.retriever import Retriever
-rag = RAG()
-MODELS = ["ollama", "cohere"]
+if __name__ == "__main__":
+ load_dotenv()
+ retriever = Retriever()
+ ss = st.session_state
+ st.header("Retrieval Augmented Generation")
+ model = st.selectbox("Model", options=MODELS)
-def upload_pdfs():
files = st.file_uploader(
"Choose pdfs to add to the knowledge base",
type="pdf",
accept_multiple_files=True,
)
- if not files:
- return
-
- with st.spinner("Indexing documents..."):
- for file in files:
- source = file.name
- blob = Blob.from_data(file.read())
- rag.add_pdf(blob, source)
-
-
-if __name__ == "__main__":
- ss = st.session_state
- st.header("RAG-UI")
-
- model = st.selectbox("Model", options=MODELS)
-
- upload_pdfs()
+ if files:
+ with st.spinner("Indexing documents..."):
+ for file in files:
+ source = file.name
+ blob = Blob.from_data(file.read())
+ retriever.add_pdf(blob=blob, source=source)
with st.form(key="query"):
query = st.text_area(
@@ -51,13 +44,13 @@ if __name__ == "__main__":
(b,) = st.columns(1)
(result_column, context_column) = st.columns(2)
- if submit:
+ if submit and model:
if not query:
st.stop()
query = ss.get("query", "")
with st.spinner("Searching for documents..."):
- documents = rag.retrieve(query)
+ documents = retriever.retrieve(query)
prompt = Prompt(query, documents)
@@ -72,4 +65,4 @@ if __name__ == "__main__":
with result_column:
generator = get_generator(model)
st.markdown("### Answer")
- st.write_stream(rag.generate(generator, prompt))
+ st.write_stream(generator.generate(generator, prompt))