author    | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-08 22:28:47 +0200
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-08 22:28:47 +0200
commit    | d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (patch)
tree      | c2e02b81b410092f083d8c1d6c606e5975f2b568 /rag
parent    | 5f777ecdfbf486e5057d31547bdc53358037dce0 (diff)
Wip refactor
Diffstat (limited to 'rag')
-rw-r--r-- | rag/cli.py                  |  2
-rw-r--r-- | rag/db/vector.py            | 25
-rw-r--r-- | rag/llm/cohere_generator.py | 29
-rw-r--r-- | rag/llm/encoder.py          | 15
-rw-r--r-- | rag/llm/generator.py        | 33
-rw-r--r-- | rag/llm/ollama_generator.py | 76
-rw-r--r-- | rag/parser/pdf.py           | 13
-rw-r--r-- | rag/rag.py                  | 34
-rw-r--r-- | rag/ui.py                   | 59
9 files changed, 202 insertions, 84 deletions
diff --git a/rag/cli.py b/rag/cli.py
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -22,7 +22,7 @@ if __name__ == "__main__":
             if query:
                 result = rag.retrive(query)
                 print("Answer: \n")
-                print(result.answer)
+                print(result.answer + "\n")
         case _:
             print("Invalid option!")
diff --git a/rag/db/vector.py b/rag/db/vector.py
index bbbbf32..fd2b2c2 100644
--- a/rag/db/vector.py
+++ b/rag/db/vector.py
@@ -5,7 +5,7 @@ from typing import Dict, List
 from loguru import logger as log
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import StrictFloat
-from qdrant_client.models import Distance, PointStruct, ScoredPoint, VectorParams
+from qdrant_client.models import Distance, PointStruct, VectorParams
 
 
 @dataclass
@@ -15,11 +15,18 @@ class Point:
     payload: Dict[str, str]
 
 
+@dataclass
+class Document:
+    title: str
+    text: str
+
+
 class VectorDB:
-    def __init__(self):
+    def __init__(self, score_threshold: float = 0.6):
         self.dim = int(os.environ["EMBEDDING_DIM"])
         self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
         self.client = QdrantClient(url=os.environ["QDRANT_URL"])
+        self.score_threshold = score_threshold
         self.__configure()
 
     def __configure(self):
@@ -47,12 +54,20 @@ class VectorDB:
             max_retries=3,
         )
 
-    def search(self, query: List[float], limit: int = 4) -> List[ScoredPoint]:
+    def search(self, query: List[float], limit: int = 5) -> List[Document]:
         log.debug("Searching for vectors...")
         hits = self.client.search(
             collection_name=self.collection_name,
             query_vector=query,
             limit=limit,
-            score_threshold=0.6,
+            score_threshold=self.score_threshold,
+        )
+        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+        return list(
+            map(
+                lambda h: Document(
+                    title=h.payload.get("source", ""), text=h.payload["text"]
+                ),
+                hits,
+            )
         )
-        return hits
diff --git a/rag/llm/cohere_generator.py b/rag/llm/cohere_generator.py
new file mode 100644
index 0000000..a6feacd
--- /dev/null
+++ b/rag/llm/cohere_generator.py
@@ -0,0 +1,29 @@
+import os
+from typing import Any, Generator
+import cohere
+
+from dataclasses import asdict
+try:
+    from rag.llm.ollama_generator import Prompt
+except ModuleNotFoundError:
+    from llm.ollama_generator import Prompt
+from loguru import logger as log
+
+
+class CohereGenerator:
+    def __init__(self) -> None:
+        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer from cohere")
+        for event in self.client.chat_stream(
+            message=prompt.query,
+            documents=[asdict(d) for d in prompt.documents],
+            prompt_truncation="AUTO",
+        ):
+            if event.event_type == "text-generation":
+                yield event.text
+            elif event.event_type == "citation-generation":
+                yield event.citations
+            elif event.event_type == "stream-end":
+                yield event.finish_reason
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 95f3c6a..a59b1b4 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,5 +1,6 @@
 import os
-from typing import Iterator, List
+from pathlib import Path
+from typing import List, Dict
 from uuid import uuid4
 
 import ollama
@@ -13,6 +14,7 @@ try:
 except ModuleNotFoundError:
     from db.vector import Point
 
+
 class Encoder:
     def __init__(self) -> None:
         self.model = os.environ["ENCODER_MODEL"]
@@ -21,13 +23,20 @@ class Encoder:
     def __encode(self, prompt: str) -> List[StrictFloat]:
         return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
 
-    def encode_document(self, chunks: Iterator[Document]) -> List[Point]:
+    def __get_source(self, metadata: Dict[str, str]) -> str:
+        source = metadata["source"]
+        return Path(source).name
+
+    def encode_document(self, chunks: List[Document]) -> List[Point]:
         log.debug("Encoding document...")
         return [
             Point(
                 id=uuid4().hex,
                 vector=self.__encode(chunk.page_content),
-                payload={"text": chunk.page_content},
+                payload={
+                    "text": chunk.page_content,
+                    "source": self.__get_source(chunk.metadata),
+                },
             )
             for chunk in chunks
         ]
diff --git a/rag/llm/generator.py b/rag/llm/generator.py
deleted file mode 100644
index 8c7702f..0000000
--- a/rag/llm/generator.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-from dataclasses import dataclass
-
-import ollama
-from loguru import logger as log
-
-
-@dataclass
-class Prompt:
-    query: str
-    context: str
-
-
-class Generator:
-    def __init__(self) -> None:
-        self.model = os.environ["GENERATOR_MODEL"]
-
-    def __metaprompt(self, prompt: Prompt) -> str:
-        metaprompt = (
-            "Answer the following question using the provided context.\n"
-            "If you can't find the answer, do not pretend you know it,"
-            'but answer "I don\'t know".\n\n'
-            f"Question: {prompt.query.strip()}\n\n"
-            "Context:\n"
-            f"{prompt.context.strip()}\n\n"
-            "Answer:\n"
-        )
-        return metaprompt
-
-    def generate(self, prompt: Prompt) -> str:
-        log.debug("Generating answer...")
-        metaprompt = self.__metaprompt(prompt)
-        return ollama.generate(model=self.model, prompt=metaprompt)
diff --git a/rag/llm/ollama_generator.py b/rag/llm/ollama_generator.py
new file mode 100644
index 0000000..dd17f8d
--- /dev/null
+++ b/rag/llm/ollama_generator.py
@@ -0,0 +1,76 @@
+import os
+from dataclasses import dataclass
+from typing import Any, Generator, List
+
+import ollama
+from loguru import logger as log
+
+try:
+    from rag.db.vector import Document
+except ModuleNotFoundError:
+    from db.vector import Document
+
+
+@dataclass
+class Prompt:
+    query: str
+    documents: List[Document]
+
+
+SYSTEM_PROMPT = (
+    "# System Preamble"
+    "## Basic Rules"
+    "When you answer the user's requests, you cite your sources in your answers, according to those instructions."
+    "Answer the following question using the provided context.\n"
+    "## Style Guide"
+    "Unless the user asks for a different style of answer, you should answer "
+    "in full sentences, using proper grammar and spelling."
+)
+
+
+class OllamaGenerator:
+    def __init__(self) -> None:
+        self.model = os.environ["GENERATOR_MODEL"]
+
+    def __context(self, documents: List[Document]) -> str:
+        results = [
+            f"Document: {i}\ntitle: {doc.title}\n{doc.text}"
+            for i, doc in enumerate(documents)
+        ]
+        return "\n".join(results)
+
+    def __metaprompt(self, prompt: Prompt) -> str:
+        # Include sources
+        metaprompt = (
+            f'Question: "{prompt.query.strip()}"\n\n'
+            "Context:\n"
+            "<result>\n"
+            f"{self.__context(prompt.documents)}\n\n"
+            "</result>\n"
+            "Carefully perform the following instructions, in order, starting each "
+            "with a new line.\n"
+            "Firstly, Decide which of the retrieved documents are relevant to the "
+            "user's last input by writing 'Relevant Documents:' followed by "
+            "comma-separated list of document numbers.\n If none are relevant, you "
+            "should instead write 'None'.\n"
+            "Secondly, Decide which of the retrieved documents contain facts that "
+            "should be cited in a good answer to the user's last input by writing "
+            "'Cited Documents:' followed a comma-separated list of document numbers. "
+            "If you dont want to cite any of them, you should instead write 'None'.\n"
+            "Thirdly, Write 'Answer:' followed by a response to the user's last input "
+            "in high quality natural english. Use the retrieved documents to help you. "
+            "Do not insert any citations or grounding markup.\n"
+            "Finally, Write 'Grounded answer:' followed by a response to the user's "
+            "last input in high quality natural english. Use the symbols <co: doc> and "
+            "</co: doc> to indicate when a fact comes from a document in the search "
+            "result, e.g <co: 0>my fact</co: 0> for a fact from document 0."
+        )
+        return metaprompt
+
+    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+        log.debug("Generating answer...")
+        metaprompt = self.__metaprompt(prompt)
+        for chunk in ollama.generate(
+            model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True
+        ):
+            yield chunk
diff --git a/rag/parser/pdf.py b/rag/parser/pdf.py
index cbd86a3..ca9b72d 100644
--- a/rag/parser/pdf.py
+++ b/rag/parser/pdf.py
@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import Iterator
+from typing import Iterator, List, Optional
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_core.documents import Document
@@ -15,15 +15,20 @@ class PDFParser:
         self.parser = PyPDFParser(password=None, extract_images=False)
 
     def from_data(self, blob: Blob) -> Iterator[Document]:
-        yield from self.parser.parse(blob)
+        return self.parser.parse(blob)
 
     def from_path(self, path: Path) -> Iterator[Document]:
         return Blob.from_path(path)
 
-    def chunk(self, content: Iterator[Document]):
+    def chunk(
+        self, document: Iterator[Document], source: Optional[str] = None
+    ) -> List[Document]:
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=int(os.environ["CHUNK_SIZE"]),
             chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
         )
-        chunks = splitter.split_documents(content)
+        chunks = splitter.split_documents(document)
+        if source is not None:
+            for c in chunks:
+                c.metadata["source"] = source
         return chunks
diff --git a/rag/rag.py b/rag/rag.py
--- a/rag/rag.py
+++ b/rag/rag.py
@@ -5,20 +5,22 @@ from typing import List
 
 from dotenv import load_dotenv
 from loguru import logger as log
-from qdrant_client.models import StrictFloat
+
 try:
-    from rag.db.vector import VectorDB
+    from rag.db.vector import VectorDB, Document
     from rag.db.document import DocumentDB
     from rag.llm.encoder import Encoder
-    from rag.llm.generator import Generator, Prompt
+    from rag.llm.ollama_generator import OllamaGenerator, Prompt
+    from rag.llm.cohere_generator import CohereGenerator
     from rag.parser.pdf import PDFParser
 except ModuleNotFoundError:
-    from db.vector import VectorDB
+    from db.vector import VectorDB, Document
     from db.document import DocumentDB
     from llm.encoder import Encoder
-    from llm.generator import Generator, Prompt
+    from llm.ollama_generator import OllamaGenerator, Prompt
+    from llm.cohere_generator import CohereGenerator
     from parser.pdf import PDFParser
 
@@ -34,7 +36,7 @@ class RAG:
         # FIXME: load this somewhere else?
         load_dotenv()
         self.pdf_parser = PDFParser()
-        self.generator = Generator()
+        self.generator = CohereGenerator()
         self.encoder = Encoder()
         self.vector_db = VectorDB()
         self.doc_db = DocumentDB()
@@ -43,23 +45,19 @@ class RAG:
         blob = self.pdf_parser.from_path(path)
         self.add_pdf_from_blob(blob)
 
-    def add_pdf_from_blob(self, blob: BytesIO):
+    def add_pdf_from_blob(self, blob: BytesIO, source: str):
         if self.doc_db.add(blob):
             log.debug("Adding pdf to vector database...")
-            chunks = self.pdf_parser.from_data(blob)
+            document = self.pdf_parser.from_data(blob)
+            chunks = self.pdf_parser.chunk(document, source)
             points = self.encoder.encode_document(chunks)
             self.vector_db.add(points)
         else:
             log.debug("Document already exists!")
 
-    def __context(self, query_emb: List[StrictFloat], limit: int) -> str:
-        hits = self.vector_db.search(query_emb, limit)
-        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
-        return [h.payload["text"] for h in hits]
-
-    def retrive(self, query: str, limit: int = 5) -> Response:
+    def search(self, query: str, limit: int = 5) -> List[Document]:
         query_emb = self.encoder.encode_query(query)
-        context = self.__context(query_emb, limit)
-        prompt = Prompt(query, "\n".join(context))
-        answer = self.generator.generate(prompt)["response"]
-        return Response(query, context, answer)
+        return self.vector_db.search(query_emb, limit)
+
+    def retrieve(self, prompt: Prompt):
+        yield from self.generator.generate(prompt)
diff --git a/rag/ui.py b/rag/ui.py
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -4,8 +4,10 @@ from langchain_community.document_loaders.blob_loaders import Blob
 
 try:
     from rag.rag import RAG
+    from rag.llm.ollama_generator import Prompt
 except ModuleNotFoundError:
     from rag import RAG
+    from llm.ollama_generator import Prompt
 
 
 rag = RAG()
@@ -16,9 +18,15 @@ def upload_pdfs():
         type="pdf",
         accept_multiple_files=True,
     )
-    for file in files:
-        blob = Blob.from_data(file.read())
-        rag.add_pdf_from_blob(blob)
+
+    if not files:
+        return
+
+    with st.spinner("Indexing documents..."):
+        for file in files:
+            source = file.name
+            blob = Blob.from_data(file.read())
+            rag.add_pdf_from_blob(blob, source)
 
 
 if __name__ == "__main__":
@@ -26,30 +34,41 @@ if __name__ == "__main__":
     st.header("RAG-UI")
 
     upload_pdfs()
-    query = st.text_area(
-        "query",
-        key="query",
-        height=100,
-        placeholder="Enter query here",
-        help="",
-        label_visibility="collapsed",
-        disabled=False,
-    )
+
+    with st.form(key="query"):
+        query = st.text_area(
+            "query",
+            key="query",
+            height=100,
+            placeholder="Enter query here",
+            help="",
+            label_visibility="collapsed",
+            disabled=False,
+        )
+        submit = st.form_submit_button("Generate")
 
     (b,) = st.columns(1)
     (result_column, context_column) = st.columns(2)
 
-    if b.button("Generate", disabled=False, type="primary", use_container_width=True):
+    if submit:
+        if not query:
+            st.stop()
+
         query = ss.get("query", "")
-        with st.spinner("Generating answer..."):
-            response = rag.retrieve(query)
+        with st.spinner("Searching for documents..."):
+            documents = rag.search(query)
 
-        with result_column:
-            st.markdown("### Answer")
-            st.markdown(response.answer)
+        prompt = Prompt(query, documents)
 
         with context_column:
             st.markdown("### Context")
-            for c in response.context:
-                st.markdown(c)
+            for i, doc in enumerate(documents):
+                st.markdown(f"### Document {i}")
+                st.markdown(f"**Title: {doc.title}**")
+                st.markdown(doc.text)
                 st.markdown("---")
+
+        with result_column:
+            st.markdown("### Answer")
+            st.write_stream(rag.retrieve(prompt))
+
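
For orientation, a minimal sketch of how the refactored pieces fit together outside the Streamlit UI. This is an assumption-laden example rather than part of the commit: the file path and query are hypothetical, the environment variables read by `RAG`, `Encoder`, and `VectorDB` must already be set, and the exact chunk objects yielded by `retrieve` depend on whether `CohereGenerator` or `OllamaGenerator` is active.

```python
from langchain_community.document_loaders.blob_loaders import Blob

from rag.rag import RAG
from rag.llm.ollama_generator import Prompt

rag = RAG()

# Index a PDF the same way the UI does: the file name is passed as "source",
# stored on each chunk, and later surfaces as the Document title.
pdf_path = "example.pdf"  # hypothetical file
with open(pdf_path, "rb") as f:
    blob = Blob.from_data(f.read())
rag.add_pdf_from_blob(blob, pdf_path)

# Search the vector DB for relevant chunks, then stream a grounded answer.
query = "What is this document about?"  # hypothetical query
documents = rag.search(query)
prompt = Prompt(query=query, documents=documents)
for chunk in rag.retrieve(prompt):
    print(chunk, end="", flush=True)
```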