summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
Diffstat (limited to 'rag')
-rw-r--r--rag/cli.py2
-rw-r--r--rag/db/vector.py25
-rw-r--r--rag/llm/cohere_generator.py29
-rw-r--r--rag/llm/encoder.py15
-rw-r--r--rag/llm/generator.py33
-rw-r--r--rag/llm/ollama_generator.py76
-rw-r--r--rag/parser/pdf.py13
-rw-r--r--rag/rag.py34
-rw-r--r--rag/ui.py59
9 files changed, 202 insertions, 84 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 5ea1a47..c470db3 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -22,7 +22,7 @@ if __name__ == "__main__":
if query:
result = rag.retrive(query)
print("Answer: \n")
- print(result.answer)
+ print(result.answer + "\n")
case _:
print("Invalid option!")
diff --git a/rag/db/vector.py b/rag/db/vector.py
index bbbbf32..fd2b2c2 100644
--- a/rag/db/vector.py
+++ b/rag/db/vector.py
@@ -5,7 +5,7 @@ from typing import Dict, List
from loguru import logger as log
from qdrant_client import QdrantClient
from qdrant_client.http.models import StrictFloat
-from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams
+from qdrant_client.models import Distance, PointStruct, VectorParams
@dataclass
@@ -15,11 +15,18 @@ class Point:
payload: Dict[str, str]
+@dataclass
+class Document:
+ title: str
+ text: str
+
+
class VectorDB:
- def __init__(self):
+ def __init__(self, score_threshold: float = 0.6):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
+ self.score_threshold = score_threshold
self.__configure()
def __configure(self):
@@ -47,12 +54,20 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
+ def search(self, query: List[float], limit: int = 5) -> List[Document]:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
query_vector=query,
limit=limit,
- score_threshold=0.6,
+ score_threshold=self.score_threshold,
+ )
+ log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+ return list(
+ map(
+ lambda h: Document(
+ title=h.payload.get("source", ""), text=h.payload["text"]
+ ),
+ hits,
+ )
)
- return hits
diff --git a/rag/llm/cohere_generator.py b/rag/llm/cohere_generator.py
new file mode 100644
index 0000000..a6feacd
--- /dev/null
+++ b/rag/llm/cohere_generator.py
@@ -0,0 +1,29 @@
+import os
+from typing import Any, Generator
+import cohere
+
+from dataclasses import asdict
+try:
+ from rag.llm.ollama_generator import Prompt
+except ModuleNotFoundError:
+ from llm.ollama_generator import Prompt
+from loguru import logger as log
+
+
+class CohereGenerator:
+ def __init__(self) -> None:
+ self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+
+ def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ log.debug("Generating answer from cohere")
+ for event in self.client.chat_stream(
+ message=prompt.query,
+ documents=[asdict(d) for d in prompt.documents],
+ prompt_truncation="AUTO",
+ ):
+ if event.event_type == "text-generation":
+ yield event.text
+ elif event.event_type == "citation-generation":
+ yield event.citations
+ elif event.event_type == "stream-end":
+ yield event.finish_reason
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 95f3c6a..a59b1b4 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,5 +1,6 @@
import os
-from typing import Iterator, List
+from pathlib import Path
+from typing import List, Dict
from uuid import uuid4
import ollama
@@ -13,6 +14,7 @@ try:
except ModuleNotFoundError:
from db.vector import Point
+
class Encoder:
def __init__(self) -> None:
self.model = os.environ["ENCODER_MODEL"]
@@ -21,13 +23,20 @@ class Encoder:
def __encode(self, prompt: str) -> List[StrictFloat]:
return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
- def encode_document(self, chunks: Iterator[Document]) -> List[Point]:
+ def __get_source(self, metadata: Dict[str, str]) -> str:
+ source = metadata["source"]
+ return Path(source).name
+
+ def encode_document(self, chunks: List[Document]) -> List[Point]:
log.debug("Encoding document...")
return [
Point(
id=uuid4().hex,
vector=self.__encode(chunk.page_content),
- payload={"text": chunk.page_content},
+ payload={
+ "text": chunk.page_content,
+ "source": self.__get_source(chunk.metadata),
+ },
)
for chunk in chunks
]
diff --git a/rag/llm/generator.py b/rag/llm/generator.py
deleted file mode 100644
index 8c7702f..0000000
--- a/rag/llm/generator.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-from dataclasses import dataclass
-
-import ollama
-from loguru import logger as log
-
-
-@dataclass
-class Prompt:
- query: str
- context: str
-
-
-class Generator:
- def __init__(self) -> None:
- self.model = os.environ["GENERATOR_MODEL"]
-
- def __metaprompt(self, prompt: Prompt) -> str:
- metaprompt = (
- "Answer the following question using the provided context.\n"
- "If you can't find the answer, do not pretend you know it,"
- 'but answer "I don\'t know".\n\n'
- f"Question: {prompt.query.strip()}\n\n"
- "Context:\n"
- f"{prompt.context.strip()}\n\n"
- "Answer:\n"
- )
- return metaprompt
-
- def generate(self, prompt: Prompt) -> str:
- log.debug("Generating answer...")
- metaprompt = self.__metaprompt(prompt)
- return ollama.generate(model=self.model, prompt=metaprompt)
diff --git a/rag/llm/ollama_generator.py b/rag/llm/ollama_generator.py
new file mode 100644
index 0000000..dd17f8d
--- /dev/null
+++ b/rag/llm/ollama_generator.py
@@ -0,0 +1,76 @@
+import os
+from dataclasses import dataclass
+from typing import Any, Generator, List
+
+import ollama
+from loguru import logger as log
+
+try:
+ from rag.db.vector import Document
+except ModuleNotFoundError:
+ from db.vector import Document
+
+
+@dataclass
+class Prompt:
+ query: str
+ documents: List[Document]
+
+
+SYSTEM_PROMPT = (
+ "# System Preamble"
+ "## Basic Rules"
+ "When you answer the user's requests, you cite your sources in your answers, according to those instructions."
+ "Answer the following question using the provided context.\n"
+ "## Style Guide"
+ "Unless the user asks for a different style of answer, you should answer "
+ "in full sentences, using proper grammar and spelling."
+)
+
+
+class OllamaGenerator:
+ def __init__(self) -> None:
+ self.model = os.environ["GENERATOR_MODEL"]
+
+ def __context(self, documents: List[Document]) -> str:
+ results = [
+ f"Document: {i}\ntitle: {doc.title}\n{doc.text}"
+ for i, doc in enumerate(documents)
+ ]
+ return "\n".join(results)
+
+ def __metaprompt(self, prompt: Prompt) -> str:
+ # Include sources
+ metaprompt = (
+ f'Question: "{prompt.query.strip()}"\n\n'
+ "Context:\n"
+ "<result>\n"
+ f"{self.__context(prompt.documents)}\n\n"
+ "</result>\n"
+ "Carefully perform the following instructions, in order, starting each "
+ "with a new line.\n"
+ "Firstly, Decide which of the retrieved documents are relevant to the "
+ "user's last input by writing 'Relevant Documents:' followed by "
+ "comma-separated list of document numbers.\n If none are relevant, you "
+ "should instead write 'None'.\n"
+ "Secondly, Decide which of the retrieved documents contain facts that "
+ "should be cited in a good answer to the user's last input by writing "
+ "'Cited Documents:' followed a comma-separated list of document numbers. "
+ "If you dont want to cite any of them, you should instead write 'None'.\n"
+ "Thirdly, Write 'Answer:' followed by a response to the user's last input "
+ "in high quality natural english. Use the retrieved documents to help you. "
+ "Do not insert any citations or grounding markup.\n"
+ "Finally, Write 'Grounded answer:' followed by a response to the user's "
+ "last input in high quality natural english. Use the symbols <co: doc> and "
+ "</co: doc> to indicate when a fact comes from a document in the search "
+ "result, e.g <co: 0>my fact</co: 0> for a fact from document 0."
+ )
+ return metaprompt
+
+ def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ log.debug("Generating answer...")
+ metaprompt = self.__metaprompt(prompt)
+ for chunk in ollama.generate(
+ model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True
+ ):
+ yield chunk
diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py
index cbd86a3..ca9b72d 100644
--- a/rag/parser/pdf.py
+++ b/rag/parser/pdf.py
@@ -1,6 +1,6 @@
import os
from pathlib import Path
-from typing import Iterator
+from typing import Iterator, List, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
@@ -15,15 +15,20 @@ class PDFParser:
self.parser = PyPDFParser(password=None, extract_images=False)
def from_data(self, blob: Blob) -> Iterator[Document]:
- yield from self.parser.parse(blob)
+ return self.parser.parse(blob)
def from_path(self, path: Path) -> Iterator[Document]:
return Blob.from_path(path)
- def chunk(self, content: Iterator[Document]):
+ def chunk(
+ self, document: Iterator[Document], source: Optional[str] = None
+ ) -> List[Document]:
splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ["CHUNK_SIZE"]),
chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
)
- chunks = splitter.split_documents(content)
+ chunks = splitter.split_documents(document)
+ if source is not None:
+ for c in chunks:
+ c.metadata["source"] = source
return chunks
diff --git a/rag/rag.py b/rag/rag.py
index cd4537e..93f9fd7 100644
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -5,20 +5,22 @@ from typing import List
from dotenv import load_dotenv
from loguru import logger as log
-from qdrant_client.models import StrictFloat
+
try:
- from rag.db.vector import VectorDB
+ from rag.db.vector import VectorDB, Document
from rag.db.document import DocumentDB
from rag.llm.encoder import Encoder
- from rag.llm.generator import Generator, Prompt
+ from rag.llm.ollama_generator import OllamaGenerator, Prompt
+ from rag.llm.cohere_generator import CohereGenerator
from rag.parser.pdf import PDFParser
except ModuleNotFoundError:
- from db.vector import VectorDB
+ from db.vector import VectorDB, Document
from db.document import DocumentDB
from llm.encoder import Encoder
- from llm.generator import Generator, Prompt
+ from llm.ollama_generator import OllamaGenerator, Prompt
+ from llm.cohere_generator import CohereGenerator
from parser.pdf import PDFParser
@@ -34,7 +36,7 @@ class RAG:
# FIXME: load this somewhere else?
load_dotenv()
self.pdf_parser = PDFParser()
- self.generator = Generator()
+ self.generator = CohereGenerator()
self.encoder = Encoder()
self.vector_db = VectorDB()
self.doc_db = DocumentDB()
@@ -43,23 +45,19 @@ class RAG:
blob = self.pdf_parser.from_path(path)
self.add_pdf_from_blob(blob)
- def add_pdf_from_blob(self, blob: BytesIO):
+ def add_pdf_from_blob(self, blob: BytesIO, source: str):
if self.doc_db.add(blob):
log.debug("Adding pdf to vector database...")
- chunks = self.pdf_parser.from_data(blob)
+ document = self.pdf_parser.from_data(blob)
+ chunks = self.pdf_parser.chunk(document, source)
points = self.encoder.encode_document(chunks)
self.vector_db.add(points)
else:
log.debug("Document already exists!")
- def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
- hits = self.vector_db.search(query_emb, limit)
- log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
- return [h.payload["text"] for h in hits]
-
- def retrive(self, query: str, limit: int = 5) -> Response:
+ def search(self, query: str, limit: int = 5) -> List[Document]:
query_emb = self.encoder.encode_query(query)
- context = self.__context(query_emb, limit)
- prompt = Prompt(query, "\n".join(context))
- answer = self.generator.generate(prompt)["response"]
- return Response(query, context, answer)
+ return self.vector_db.search(query_emb, limit)
+
+ def retrieve(self, prompt: Prompt):
+ yield from self.generator.generate(prompt)
diff --git a/rag/ui.py b/rag/ui.py
index 37c50dd..84dbbeb 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -4,8 +4,10 @@ from langchain_community.document_loaders.blob_loaders import Blob
try:
from rag.rag import RAG
+ from rag.llm.ollama_generator import Prompt
except ModuleNotFoundError:
from rag import RAG
+ from llm.ollama_generator import Prompt
rag = RAG()
@@ -16,9 +18,15 @@ def upload_pdfs():
type="pdf",
accept_multiple_files=True,
)
- for file in files:
- blob = Blob.from_data(file.read())
- rag.add_pdf_from_blob(blob)
+
+ if not files:
+ return
+
+ with st.spinner("Indexing documents..."):
+ for file in files:
+ source = file.name
+ blob = Blob.from_data(file.read())
+ rag.add_pdf_from_blob(blob, source)
if __name__ == "__main__":
@@ -26,30 +34,41 @@ if __name__ == "__main__":
st.header("RAG-UI")
upload_pdfs()
- query = st.text_area(
- "query",
- key="query",
- height=100,
- placeholder="Enter query here",
- help="",
- label_visibility="collapsed",
- disabled=False,
- )
+
+ with st.form(key="query"):
+ query = st.text_area(
+ "query",
+ key="query",
+ height=100,
+ placeholder="Enter query here",
+ help="",
+ label_visibility="collapsed",
+ disabled=False,
+ )
+ submit = st.form_submit_button("Generate")
(b,) = st.columns(1)
(result_column, context_column) = st.columns(2)
- if b.button("Generate", disabled=False, type="primary", use_container_width=True):
+ if submit:
+ if not query:
+ st.stop()
+
query = ss.get("query", "")
- with st.spinner("Generating answer..."):
- response = rag.retrieve(query)
+ with st.spinner("Searching for documents..."):
+ documents = rag.search(query)
- with result_column:
- st.markdown("### Answer")
- st.markdown(response.answer)
+ prompt = Prompt(query, documents)
with context_column:
st.markdown("### Context")
- for c in response.context:
- st.markdown(c)
+ for i, doc in enumerate(documents):
+ st.markdown(f"### Document {i}")
+ st.markdown(f"**Title: {doc.title}**")
+ st.markdown(doc.text)
st.markdown("---")
+
+ with result_column:
+ st.markdown("### Answer")
+ st.write_stream(rag.retrieve(prompt))
+