summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:14:00 +0200
commit91ddb3672e514fa9824609ff047d7cab0c65631a (patch)
tree009fd82618588d2960b5207128e86875f73cccdc
parentd487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff)
Refactor
-rw-r--r--notebooks/testing.ipynb53
-rw-r--r--poetry.lock2
-rw-r--r--rag/cli.py1
-rw-r--r--rag/generator/__init__.py15
-rw-r--r--rag/generator/abstract.py19
-rw-r--r--rag/generator/cohere.py (renamed from rag/llm/cohere_generator.py)10
-rw-r--r--rag/generator/ollama.py (renamed from rag/llm/ollama_generator.py)21
-rw-r--r--rag/generator/prompt.py14
-rw-r--r--rag/parser/__init__.py0
-rw-r--r--rag/rag.py63
-rw-r--r--rag/retriever/__init__.py (renamed from rag/db/__init__.py)0
-rw-r--r--rag/retriever/document.py (renamed from rag/db/document.py)0
-rw-r--r--rag/retriever/encoder.py (renamed from rag/llm/encoder.py)6
-rw-r--r--rag/retriever/parser/__init__.py (renamed from rag/llm/__init__.py)0
-rw-r--r--rag/retriever/parser/pdf.py (renamed from rag/parser/pdf.py)8
-rw-r--r--rag/retriever/retriever.py37
-rw-r--r--rag/retriever/vector.py (renamed from rag/db/vector.py)0
-rw-r--r--rag/ui.py21
18 files changed, 193 insertions, 77 deletions
diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb
index 18827fc..bcaa613 100644
--- a/notebooks/testing.ipynb
+++ b/notebooks/testing.ipynb
@@ -98,10 +98,61 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "05f7068c-b4c6-47b2-ac62-79c021838500",
"metadata": {},
"outputs": [],
+ "source": [
+ "from rag.generator import get_generator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "31dd3e7a-1d6c-4cc6-8ffe-f83478f95875",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = get_generator(\"ollama\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "749c58c4-5250-464c-b4f4-b5b4a00be3b3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y = get_generator(\"ollama\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c59ece14-11e0-416c-afad-bc2ff6a526ec",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x is y"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0e664d3d-a787-45e5-8b71-8e5e0f348443",
+ "metadata": {},
+ "outputs": [],
"source": []
}
],
diff --git a/poetry.lock b/poetry.lock
index ffeea46..8a2b3a2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -388,7 +388,7 @@ files = [
[[package]]
name = "cffi"
version = "1.16.0"
-description = "Foreign Function Interface for Python calling C code."
+description = "Foreign Function AbstractGenerator for Python calling C code."
optional = false
python-versions = ">=3.8"
files = [
diff --git a/rag/cli.py b/rag/cli.py
index c470db3..d5651c1 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -25,4 +25,3 @@ if __name__ == "__main__":
print(result.answer + "\n")
case _:
print("Invalid option!")
-
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
new file mode 100644
index 0000000..7da603c
--- /dev/null
+++ b/rag/generator/__init__.py
@@ -0,0 +1,15 @@
+from typing import Type
+
+from .abstract import AbstractGenerator
+from .ollama import Ollama
+from .cohere import Cohere
+
+
+def get_generator(model: str) -> Type[AbstractGenerator]:
+ match model:
+ case "ollama":
+ return Ollama()
+ case "cohere":
+ return Cohere()
+ case _:
+ exit(1)
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
new file mode 100644
index 0000000..a53b5d8
--- /dev/null
+++ b/rag/generator/abstract.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+
+from typing import Any, Generator
+
+from .prompt import Prompt
+
+
+class AbstractGenerator(ABC, type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ instance = super().__call__(*args, **kwargs)
+ cls._instances[cls] = instance
+ return cls._instances[cls]
+
+ @abstractmethod
+ def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ pass
diff --git a/rag/llm/cohere_generator.py b/rag/generator/cohere.py
index a6feacd..cf95c18 100644
--- a/rag/llm/cohere_generator.py
+++ b/rag/generator/cohere.py
@@ -3,14 +3,14 @@ from typing import Any, Generator
import cohere
from dataclasses import asdict
-try:
- from rag.llm.ollama_generator import Prompt
-except ModuleNotFoundError:
- from llm.ollama_generator import Prompt
+
+from .prompt import Prompt
+from .abstract import AbstractGenerator
+
from loguru import logger as log
-class CohereGenerator:
+class Cohere(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
diff --git a/rag/llm/ollama_generator.py b/rag/generator/ollama.py
index dd17f8d..ec6a55f 100644
--- a/rag/llm/ollama_generator.py
+++ b/rag/generator/ollama.py
@@ -1,21 +1,16 @@
import os
-from dataclasses import dataclass
from typing import Any, Generator, List
import ollama
from loguru import logger as log
+from .prompt import Prompt
+from .abstract import AbstractGenerator
+
try:
- from rag.db.vector import Document
+ from rag.retriever.vector import Document
except ModuleNotFoundError:
- from db.vector import Document
-
-
-@dataclass
-class Prompt:
- query: str
- documents: List[Document]
-
+ from retriever.vector import Document
SYSTEM_PROMPT = (
"# System Preamble"
@@ -28,7 +23,7 @@ SYSTEM_PROMPT = (
)
-class OllamaGenerator:
+class Ollama(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.model = os.environ["GENERATOR_MODEL"]
@@ -72,5 +67,5 @@ class OllamaGenerator:
metaprompt = self.__metaprompt(prompt)
for chunk in ollama.generate(
model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True
- ):
- yield chunk
+ ):
+ yield chunk["response"]
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
new file mode 100644
index 0000000..ed372c9
--- /dev/null
+++ b/rag/generator/prompt.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import List
+
+
+try:
+ from rag.retriever.vector import Document
+except ModuleNotFoundError:
+ from retriever.vector import Document
+
+
+@dataclass
+class Prompt:
+ query: str
+ documents: List[Document]
diff --git a/rag/parser/__init__.py b/rag/parser/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/rag/parser/__init__.py
+++ /dev/null
diff --git a/rag/rag.py b/rag/rag.py
index 93f9fd7..c95d93a 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -1,27 +1,22 @@
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
-from typing import List
+from typing import List, Optional, Type
from dotenv import load_dotenv
from loguru import logger as log
-
try:
- from rag.db.vector import VectorDB, Document
- from rag.db.document import DocumentDB
- from rag.llm.encoder import Encoder
- from rag.llm.ollama_generator import OllamaGenerator, Prompt
- from rag.llm.cohere_generator import CohereGenerator
- from rag.parser.pdf import PDFParser
+ from rag.retriever.vector import Document
+ from rag.generator.abstract import AbstractGenerator
+ from rag.retriever.retriever import Retriever
+ from rag.generator.prompt import Prompt
except ModuleNotFoundError:
- from db.vector import VectorDB, Document
- from db.document import DocumentDB
- from llm.encoder import Encoder
- from llm.ollama_generator import OllamaGenerator, Prompt
- from llm.cohere_generator import CohereGenerator
- from parser.pdf import PDFParser
+ from retriever.vector import Document
+ from generator.abstract import AbstractGenerator
+ from retriever.retriever import Retriever
+ from generator.prompt import Prompt
@dataclass
@@ -35,29 +30,23 @@ class RAG:
def __init__(self) -> None:
# FIXME: load this somewhere else?
load_dotenv()
- self.pdf_parser = PDFParser()
- self.generator = CohereGenerator()
- self.encoder = Encoder()
- self.vector_db = VectorDB()
- self.doc_db = DocumentDB()
-
- def add_pdf_from_path(self, path: Path):
- blob = self.pdf_parser.from_path(path)
- self.add_pdf_from_blob(blob)
-
- def add_pdf_from_blob(self, blob: BytesIO, source: str):
- if self.doc_db.add(blob):
- log.debug("Adding pdf to vector database...")
- document = self.pdf_parser.from_data(blob)
- chunks = self.pdf_parser.chunk(document, source)
- points = self.encoder.encode_document(chunks)
- self.vector_db.add(points)
+ self.retriever = Retriever()
+
+ def add_pdf(
+ self,
+ path: Optional[Path],
+ blob: Optional[BytesIO],
+ source: Optional[str] = None,
+ ):
+ if path:
+ self.retriever.add_pdf_from_path(path)
+ elif blob:
+ self.retriever.add_pdf_from_blob(blob, source)
else:
- log.debug("Document already exists!")
+ log.error("Both path and blob was None, no pdf added!")
- def search(self, query: str, limit: int = 5) -> List[Document]:
- query_emb = self.encoder.encode_query(query)
- return self.vector_db.search(query_emb, limit)
+ def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+ return self.retriever.retrieve(query, limit)
- def retrieve(self, prompt: Prompt):
- yield from self.generator.generate(prompt)
+ def generate(self, generator: Type[AbstractGenerator], prompt: Prompt):
+ yield from generator.generate(prompt)
diff --git a/rag/db/__init__.py b/rag/retriever/__init__.py
index e69de29..e69de29 100644
--- a/rag/db/__init__.py
+++ b/rag/retriever/__init__.py
diff --git a/rag/db/document.py b/rag/retriever/document.py
index 54ac451..54ac451 100644
--- a/rag/db/document.py
+++ b/rag/retriever/document.py
diff --git a/rag/llm/encoder.py b/rag/retriever/encoder.py
index a59b1b4..753157f 100644
--- a/rag/llm/encoder.py
+++ b/rag/retriever/encoder.py
@@ -8,11 +8,7 @@ from langchain_core.documents import Document
from loguru import logger as log
from qdrant_client.http.models import StrictFloat
-
-try:
- from rag.db.vector import Point
-except ModuleNotFoundError:
- from db.vector import Point
+from .vector import Point
class Encoder:
diff --git a/rag/llm/__init__.py b/rag/retriever/parser/__init__.py
index e69de29..e69de29 100644
--- a/rag/llm/__init__.py
+++ b/rag/retriever/parser/__init__.py
diff --git a/rag/parser/pdf.py b/rag/retriever/parser/pdf.py
index ca9b72d..410f027 100644
--- a/rag/parser/pdf.py
+++ b/rag/retriever/parser/pdf.py
@@ -1,6 +1,6 @@
import os
from pathlib import Path
-from typing import Iterator, List, Optional
+from typing import List, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
@@ -14,14 +14,14 @@ class PDFParser:
def __init__(self) -> None:
self.parser = PyPDFParser(password=None, extract_images=False)
- def from_data(self, blob: Blob) -> Iterator[Document]:
+ def from_data(self, blob: Blob) -> List[Document]:
return self.parser.parse(blob)
- def from_path(self, path: Path) -> Iterator[Document]:
+ def from_path(self, path: Path) -> Blob:
return Blob.from_path(path)
def chunk(
- self, document: Iterator[Document], source: Optional[str] = None
+ self, document: List[Document], source: Optional[str] = None
) -> List[Document]:
splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ["CHUNK_SIZE"]),
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
new file mode 100644
index 0000000..dbfdfa2
--- /dev/null
+++ b/rag/retriever/retriever.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+from typing import Optional, List
+from loguru import logger as log
+
+from io import BytesIO
+from .document import DocumentDB
+from .encoder import Encoder
+from .parser.pdf import PDFParser
+from .vector import VectorDB, Document
+
+
+class Retriever:
+ def __init__(self) -> None:
+ self.pdf_parser = PDFParser()
+ self.encoder = Encoder()
+ self.doc_db = DocumentDB()
+ self.vec_db = VectorDB()
+
+ def add_pdf_from_path(self, path: Path):
+ log.debug(f"Adding pdf from {path}")
+ blob = self.pdf_parser.from_path(path)
+ self.add_pdf_from_blob(blob)
+
+ def add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
+ if self.doc_db.add(blob):
+ log.debug("Adding pdf to vector database...")
+ document = self.pdf_parser.from_data(blob)
+ chunks = self.pdf_parser.chunk(document, source)
+ points = self.encoder.encode_document(chunks)
+ self.vec_db.add(points)
+ else:
+ log.debug("Document already exists!")
+
+ def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+ log.debug(f"Finding documents matching query: {query}")
+ query_emb = self.encoder.encode_query(query)
+ return self.vec_db.search(query_emb, limit)
diff --git a/rag/db/vector.py b/rag/retriever/vector.py
index fd2b2c2..fd2b2c2 100644
--- a/rag/db/vector.py
+++ b/rag/retriever/vector.py
diff --git a/rag/ui.py b/rag/ui.py
index 84dbbeb..83a22e2 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,16 +1,15 @@
import streamlit as st
from langchain_community.document_loaders.blob_loaders import Blob
+from .rag import RAG
-try:
- from rag.rag import RAG
- from rag.llm.ollama_generator import Prompt
-except ModuleNotFoundError:
- from rag import RAG
- from llm.ollama_generator import Prompt
+from .generator import get_generator
+from .generator.prompt import Prompt
rag = RAG()
+MODELS = ["ollama", "cohere"]
+
def upload_pdfs():
files = st.file_uploader(
@@ -26,13 +25,15 @@ def upload_pdfs():
for file in files:
source = file.name
blob = Blob.from_data(file.read())
- rag.add_pdf_from_blob(blob, source)
+ rag.add_pdf(blob, source)
if __name__ == "__main__":
ss = st.session_state
st.header("RAG-UI")
+ model = st.selectbox("Model", options=MODELS)
+
upload_pdfs()
with st.form(key="query"):
@@ -56,7 +57,7 @@ if __name__ == "__main__":
query = ss.get("query", "")
with st.spinner("Searching for documents..."):
- documents = rag.search(query)
+ documents = rag.retrieve(query)
prompt = Prompt(query, documents)
@@ -69,6 +70,6 @@ if __name__ == "__main__":
st.markdown("---")
with result_column:
+ generator = get_generator(model)
st.markdown("### Answer")
- st.write_stream(rag.retrieve(prompt))
-
+ st.write_stream(rag.generate(generator, prompt))