Diffstat (limited to 'rag')
-rw-r--r--   rag/cli.py                    |  11
-rw-r--r--   rag/generator/prompt.py       |   9
-rw-r--r--   rag/mcp/__init__.py           |   0
-rw-r--r--   rag/mcp/client.py             |   0
-rw-r--r--   rag/mcp/server.py             |   0
-rw-r--r--   rag/message.py                |  25
-rw-r--r--   rag/model.py                  |  26
-rw-r--r--   rag/retriever/document.py     |   2
-rw-r--r--   rag/retriever/encoder.py      |  42
-rw-r--r--   rag/retriever/parser/pdf.py   |   6
-rw-r--r--   rag/retriever/rerank/local.py |  45
-rw-r--r--   rag/retriever/retriever.py    |  57
-rw-r--r--   rag/retriever/vector.py       |  32
-rw-r--r--   rag/static/styles.tcss        |  48
-rw-r--r--   rag/tui.py                    |  71
-rw-r--r--   rag/ui.py                     | 134
-rw-r--r--   rag/workflows.py              |   7
17 files changed, 285 insertions, 230 deletions
diff --git a/rag/cli.py b/rag/cli.py
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
 
 from rag.generator.prompt import Prompt
 from rag.model import Rag
+from rag.retriever.retriever import FilePath
 from rag.retriever.retriever import Retriever
 
@@ -38,13 +39,13 @@ def cli():
     default=None,
 )
 @click.option("-v", "--verbose", count=True)
-def upload(directory: str, verbose: int):
+def index(directory: str, verbose: int):
     configure_logging(verbose)
     log.info(f"Uploading pfs found in directory {directory}...")
     retriever = Retriever()
     pdfs = Path(directory).glob("**/*.pdf")
     for path in tqdm(list(pdfs)):
-        retriever.add_pdf(path=path)
+        retriever.index(FilePath(path))
 
 
 @click.command()
@@ -56,7 +57,7 @@ def upload(directory: str, verbose: int):
     help="Generator and rerank model",
 )
 @click.option("-v", "--verbose", count=True)
-def rag(client: str, verbose: int):
+def search(client: str, verbose: int):
     configure_logging(verbose)
     rag = Rag(client)
     while True:
@@ -92,8 +93,8 @@ def drop():
     vec_db.delete_collection()
 
 
-cli.add_command(rag)
-cli.add_command(upload)
+cli.add_command(search)
+cli.add_command(index)
 cli.add_command(drop)
 
 if __name__ == "__main__":
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 75966e8..16ea447 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -4,10 +4,8 @@ from typing import List
 from rag.retriever.vector import Document
 
 ANSWER_INSTRUCTION = (
-    "Do not attempt to answer the query without relevant context, and do not use "
-    "prior knowledge or training data!\n"
-    "If the context does not contain the answer or is empty, only reply that you "
-    "cannot answer the query given the context."
+    "Using the information contained in the context, give a comprehensive answer to the question.\n"
+    "If the answer cannot be deduced from the context, do not give an answer.\n\n"
 )
 
 
@@ -30,8 +28,7 @@ class Prompt:
         else:
             return (
                 "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n"
-                "Using the information contained in the context, give a comprehensive answer to the question.\n"
-                "If the answer cannot be deduced from the context, do not give an answer.\n\n"
+                f"{ANSWER_INSTRUCTION}"
                 "Context:\n"
                 "---\n"
                 f"{self.__context(self.documents)}\n\n"
diff --git a/rag/mcp/__init__.py b/rag/mcp/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/mcp/__init__.py
diff --git a/rag/mcp/client.py b/rag/mcp/client.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/mcp/client.py
diff --git a/rag/mcp/server.py b/rag/mcp/server.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/mcp/server.py
diff --git a/rag/message.py b/rag/message.py
index d628982..d207596 100644
--- a/rag/message.py
+++ b/rag/message.py
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, List
+
+from loguru import logger as log
 
 
 @dataclass
@@ -13,3 +15,24 @@ class Message:
             return {"role": self.role, "message": self.content}
         else:
             return {"role": self.role, "content": self.content}
+
+
+@dataclass
+class Messages:
+    messages: List[Message]
+
+    def __add__(self, message: Message):
+        self.messages.append(message)
+
+    def __len__(self):
+        return len(self.messages)
+
+    def reset(self):
+        log.debug("Resetting messages...")
+        self.messages = []
+
+    def content(self) -> List[str]:
+        return [m.content for m in self.messages]
+
+    def rerank(self, rankings: List[int]):
+        self.messages = [self.messages[r] for r in rankings]
diff --git a/rag/model.py b/rag/model.py
index b186d43..c6df9e2 100644
--- a/rag/model.py
+++ b/rag/model.py
@@ -1,10 +1,11 @@
-from typing import Any, Generator, List
+from typing import Any, Generator, List, Optional
 
 from loguru import logger as log
 
 from rag.generator import get_generator
 from rag.generator.prompt import Prompt
-from rag.message import Message
+from rag.message import Message, Messages
+from rag.retriever.encoder import Query
 from rag.retriever.rerank import get_reranker
 from rag.retriever.retriever import Retriever
 from rag.retriever.vector import Document
@@ -21,6 +22,7 @@ class Rag:
         self.generator = get_generator(self.client)
         self.__set_roles()
 
+    # TODO: move this to messages
    def __set_roles(self):
         self.bot = "assistant" if self.client == "local" else "CHATBOT"
         self.user = "user" if self.client == "local" else "USER"
@@ -34,21 +36,21 @@ class Rag:
         self.__reset_messages()
         log.debug(f"Swapped client to {self.client}")
 
-    def __reset_messages(self):
-        log.debug("Deleting messages...")
-        self.messages = []
-
-    def retrieve(self, query: str) -> List[Document]:
-        documents = self.retriever.retrieve(query)
+    def retrieve(self, query: Query) -> List[Document]:
+        documents = self.retriever.search(query)
         log.info(f"Found {len(documents)} relevant documents")
-        return self.reranker.rerank_documents(query, documents)
+        return self.reranker.rerank(query, documents)
 
+    # TODO: move this to messages
     def add_message(self, role: str, content: str):
         self.messages.append(Message(role=role, content=content, client=self.client))
 
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
-        if self.messages:
-            messages = self.reranker.rerank_messages(prompt.query, self.messages)
+    # TODO: fix the reranking
+    def generate(
+        self, prompt: Prompt, messages: Optional[Messages] = None
+    ) -> Generator[Any, Any, Any]:
+        if messages:
+            messages = self.reranker.rerank(prompt.query, self.messages)
         else:
             messages = []
         messages.append(
diff --git a/rag/retriever/document.py b/rag/retriever/document.py
index 132ec4b..df7a057 100644
--- a/rag/retriever/document.py
+++ b/rag/retriever/document.py
@@ -44,7 +44,7 @@ class DocumentDB:
         )
         self.conn.commit()
 
-    def add(self, blob: Blob) -> bool:
+    def create(self, blob: Blob) -> bool:
         with self.conn.cursor() as cur:
             hash = self.__hash(blob)
             cur.execute(
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
index b68c3bb..8b02a14 100644
--- a/rag/retriever/encoder.py
+++ b/rag/retriever/encoder.py
@@ -1,7 +1,8 @@
+from dataclasses import dataclass
 import hashlib
 import os
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
 
 import ollama
 from langchain_core.documents import Document
@@ -9,29 +10,42 @@ from loguru import logger as log
 from qdrant_client.http.models import StrictFloat
 from tqdm import tqdm
 
-from .vector import Point
+from .vector import Documents, Point
+
+
+@dataclass
+class Query:
+    query: str
+
+
+Input = Query | Documents
 
 
 class Encoder:
     def __init__(self) -> None:
         self.model = os.environ["ENCODER_MODEL"]
-        self.query_prompt = "Represent this sentence for searching relevant passages: "
-
-    def __encode(self, prompt: str) -> List[StrictFloat]:
-        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+        self.preamble = (
+            "Represent this sentence for searching relevant passages: "
+            if "mxbai-embed-large" in model_name
+            else ""
+        )
 
     def __get_source(self, metadata: Dict[str, str]) -> str:
         source = metadata["source"]
         return Path(source).name
 
-    def encode_document(self, chunks: List[Document]) -> List[Point]:
+    def __encode(self, prompt: str) -> List[StrictFloat]:
+        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+    # TODO: move this to vec db and just return the embeddings
+    # TODO: use late chunking here
+    def __encode_document(self, chunks: List[Document]) -> List[Point]:
         log.debug("Encoding document...")
         return [
             Point(
                 id=hashlib.sha256(
                     chunk.page_content.encode(encoding="utf-8")
                 ).hexdigest(),
-                vector=self.__encode(chunk.page_content),
+                vector=list(self.__encode(chunk.page_content)),
                 payload={
                     "text": chunk.page_content,
                     "source": self.__get_source(chunk.metadata),
@@ -40,8 +54,14 @@ class Encoder:
             for chunk in tqdm(chunks)
         ]
 
-    def encode_query(self, query: str) -> List[StrictFloat]:
+    def __encode_query(self, query: str) -> List[StrictFloat]:
         log.debug(f"Encoding query: {query}")
-        if self.model == "mxbai-embed-large":
-            query = self.query_prompt + query
+        query = self.preamble + query
         return self.__encode(query)
+
+    def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]:
+        match x:
+            case Query(query):
+                return self.__encode_query(query)
+            case Documents(documents):
+                return self.__encode_document(documents)
diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py
index 4c5addc..3253dc1 100644
--- a/rag/retriever/parser/pdf.py
+++ b/rag/retriever/parser/pdf.py
@@ -8,8 +8,10 @@ from langchain_community.document_loaders.parsers.pdf import (
     PyPDFParser,
 )
 from langchain_core.documents import Document
+from rag.retriever.encoder import Chunks
 
 
+# TODO: fix the PDFParser, remove langchain
 class PDFParser:
     def __init__(self) -> None:
         self.parser = PyPDFParser(password=None, extract_images=False)
@@ -22,7 +24,7 @@ class PDFParser:
 
     def chunk(
         self, document: List[Document], source: Optional[str] = None
-    ) -> List[Document]:
+    ) -> Chunks:
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=int(os.environ["CHUNK_SIZE"]),
             chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
@@ -31,4 +33,4 @@ class PDFParser:
         if source is not None:
             for c in chunks:
                 c.metadata["source"] = source
-        return chunks
+        return Chunks(chunks)
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 231d50a..e2bef31 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,46 +1,39 @@
 import os
-from typing import List
 
 from loguru import logger as log
 from sentence_transformers import CrossEncoder
 
-from rag.message import Message
+from rag.message import Messages
+from rag.retriever.encoder import Query
 from rag.retriever.rerank.abstract import AbstractReranker
-from rag.retriever.vector import Document
+from rag.retriever.vector import Documents
+
+Context = Documents | Messages
 
 
 class Reranker(metaclass=AbstractReranker):
     def __init__(self) -> None:
         self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")
         self.top_k = int(os.environ["RERANK_TOP_K"])
-        self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
-
-    def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
-        results = self.model.rank(
-            query=query,
-            documents=[d.text for d in documents],
-            return_documents=False,
-            top_k=self.top_k,
-        )
-        ranking = list(
-            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
-        )
-        log.debug(
-            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
-        )
-        return [documents[r.get("corpus_id", 0)] for r in ranking]
+        self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"])
 
-    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+    def rerank(self, query: Query, documents: Context) -> Context:
         results = self.model.rank(
-            query=query,
-            documents=[m.content for m in messages],
+            query=query.query,
+            documents=documents.content(),
             return_documents=False,
             top_k=self.top_k,
         )
-        ranking = list(
-            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        rankings = list(
+            map(
+                lambda x: x.get("corpus_id", 0),
+                filter(
+                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+                ),
+            )
         )
         log.debug(
-            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+            f"Reranking gave {len(rankings)} relevant documents of {len(documents)}"
         )
-        return [messages[r.get("corpus_id", 0)] for r in ranking]
+        documents.rerank(rankings)
+        return documents
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index 351cfb0..7d43941 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from io import BytesIO
 from pathlib import Path
 from typing import List, Optional
@@ -5,11 +6,25 @@ from typing import List, Optional
 
 from loguru import logger as log
 
 from .document import DocumentDB
-from .encoder import Encoder
+from .encoder import Encoder, Query
 from .parser.pdf import PDFParser
 from .vector import Document, VectorDB
 
 
+@dataclass
+class FilePath:
+    path: Path
+
+
+@dataclass
+class Blob:
+    blob: BytesIO
+    source: Optional[str] = None
+
+
+FileType = FilePath | Blob
+
+
 class Retriever:
     def __init__(self) -> None:
         self.pdf_parser = PDFParser()
@@ -17,35 +32,29 @@ class Retriever:
         self.doc_db = DocumentDB()
         self.vec_db = VectorDB()
 
-    def __add_pdf_from_path(self, path: Path):
-        log.debug(f"Adding pdf from {path}")
+    def __index_pdf_from_path(self, path: Path):
+        log.debug(f"Indexing pdf from {path}")
         blob = self.pdf_parser.from_path(path)
-        self.__add_pdf_from_blob(blob)
+        self.__index_pdf_from_blob(blob, None)
 
-    def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
-        if self.doc_db.add(blob):
-            log.debug("Adding pdf to vector database...")
+    def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]):
+        if self.doc_db.create(blob):
+            log.debug("Indexing pdf to vector database...")
             document = self.pdf_parser.from_data(blob)
             chunks = self.pdf_parser.chunk(document, source)
-            points = self.encoder.encode_document(chunks)
-            self.vec_db.add(points)
+            points = self.encoder.encode(chunks)
+            self.vec_db.index(points)
         else:
             log.debug("Document already exists!")
 
-    def add_pdf(
-        self,
-        path: Optional[Path] = None,
-        blob: Optional[BytesIO] = None,
-        source: Optional[str] = None,
-    ):
-        if path:
-            self.__add_pdf_from_path(path)
-        elif blob and source:
-            self.__add_pdf_from_blob(blob, source)
-        else:
-            log.error("Invalid input!")
+    def index(self, filetype: FileType):
+        match filetype:
+            case FilePath(path):
+                self.__index_pdf_from_path(path)
+            case Blob(blob, source):
+                self.__index_pdf_from_blob(blob, source)
 
-    def retrieve(self, query: str) -> List[Document]:
-        log.debug(f"Finding documents matching query: {query}")
-        query_emb = self.encoder.encode_query(query)
+    def search(self, query: Query) -> List[Document]:
+        log.debug(f"Finding documents matching query: {query.query}")
+        query_emb = self.encoder.encode(query)
         return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index 1a484f3..b36aee8 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -21,6 +21,20 @@ class Document:
     text: str
 
 
+@dataclass
+class Documents:
+    documents: List[Document]
+
+    def __len__(self):
+        return len(self.documents)
+
+    def content(self) -> List[str]:
+        return [d.text for d in self.documents]
+
+    def rerank(self, rankings: List[int]):
+        self.documents = [self.documents[r] for r in rankings]
+
+
 class VectorDB:
     def __init__(self):
         self.dim = int(os.environ["EMBEDDING_DIM"])
@@ -47,7 +61,7 @@
         log.info(f"Deleting collection {self.collection_name}")
         self.client.delete_collection(self.collection_name)
 
-    def add(self, points: List[Point]):
+    def index(self, points: List[Point]):
         log.debug(f"Inserting {len(points)} vectors into the vector db...")
         self.client.upload_points(
             collection_name=self.collection_name,
@@ -59,7 +73,7 @@
             max_retries=3,
         )
 
-    def search(self, query: List[float]) -> List[Document]:
+    def search(self, query: List[float]) -> Documents:
         log.debug("Searching for vectors...")
         hits = self.client.search(
             collection_name=self.collection_name,
@@ -68,11 +82,13 @@
             score_threshold=self.score_threshold,
         )
         log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
-        return list(
-            map(
-                lambda h: Document(
-                    title=h.payload.get("source", ""), text=h.payload["text"]
-                ),
-                hits,
+        return Documents(
+            list(
+                map(
+                    lambda h: Document(
+                        title=h.payload.get("source", ""), text=h.payload["text"]
+                    ),
+                    hits,
+                )
             )
         )
diff --git a/rag/static/styles.tcss b/rag/static/styles.tcss
new file mode 100644
index 0000000..902f60b
--- /dev/null
+++ b/rag/static/styles.tcss
@@ -0,0 +1,48 @@
+#rag {
+    background: #151515;
+    color: #E1E1E1;
+    height: auto;
+    width: 100%;
+    margin: 0 0 2 0;
+    grid-size: 2 2;
+}
+
+#history {
+    scrollbar-size: 1 1;
+    height: 80%;
+}
+
+#output {
+    height: 100%;
+}
+
+#chat {
+    border: round #373737;
+    padding: 1;
+    margin-top: 2;
+    margin-right: 2;
+}
+
+#input {
+    height: 20%;
+    dock: bottom;
+}
+
+#references {
+    scrollbar-size: 1 1;
+    dock: right;
+    border-left: solid #373737;
+    height: 100%;
+    width: 20%;
+}
+
+Input {
+    dock: bottom;
+    align_horizontal: center;
+    width: 90%;
+    height: 10%;
+    background: #202020;
+    color: #E1E1E1;
+    border: round #373737;
+}
diff --git a/rag/tui.py b/rag/tui.py
new file mode 100644
index 0000000..ead4f71
--- /dev/null
+++ b/rag/tui.py
@@ -0,0 +1,71 @@
+from textual import events
+from textual.app import App, ComposeResult
+from textual.containers import Vertical, Container, VerticalScroll
+from textual.widgets import Input, Label, Markdown, Static, TabbedContent, TabPane
+
+JESSICA = """
+# Lady Jessica
+
+Bene Gesserit and concubine of Leto, and mother of Paul and Alia.
+"""
+
+PAUL = """
+# Paul Atreides
+
+Son of Leto and Jessica.
+"""
+
+TEXT = """\
+Docking a widget removes it from the layout and fixes its position, aligned to either the top, right, bottom, or left edges of a container.
+
+Docked widgets will not scroll out of view, making them ideal for sticky headers, footers, and sidebars.
+
+"""
+
+class TabbedApp(App):
+    """An example of tabbed content."""
+
+    CSS_PATH = "static/styles.tcss"
+
+    BINDINGS = [
+        ("n", "show_tab('rag')", "Rag"),
+        ("e", "show_tab('settings')", "Settings"),
+    ]
+
+    def compose(self) -> ComposeResult:
+        """Compose app with tabbed content."""
+        # Add the TabbedContent widget
+        with Container(id="rag"):
+            with VerticalScroll(id="references"):
+                yield Static("test3", classes="context")
+            with VerticalScroll(id="history"):
+                yield Static(TEXT * 10, classes="output")
+            with Vertical(id="chat"):
+                yield Static("test2", classes="input")
+        # with TabbedContent(initial="rag"):
+        #     with TabPane("RAG", id="rag"):  # First tab
+        #         # yield Input(placeholder=">>", id="chat")
+        #         yield Static("test1", classes="chat")
+        #         yield Static("test", classes="context")
+        #     with TabPane("Settings", id="settings"):
+        #         yield Markdown(JESSICA)
+        #         with TabbedContent("Paul", "Alia"):
+        #             yield TabPane("Paul", Label("First child"))
+        #             yield TabPane("Alia", Label("Second child"))
+
+    def action_show_tab(self, tab: str) -> None:
+        """Switch to a new tab."""
+        self.get_child_by_type(TabbedContent).active = tab
+
+    def on_key(self, event: events.Key) -> None:
+        if event.key == "s":
+            self.action_show_tab("settings")
+        if event.key == "r":
+            self.action_show_tab("rag")
+        if event.key == "q":
+            self.exit()
+
+
+if __name__ == "__main__":
+    app = TabbedApp()
+    app.run()
diff --git a/rag/ui.py b/rag/ui.py
deleted file mode 100644
index 2192ad8..0000000
--- a/rag/ui.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from typing import List
-
-import streamlit as st
-from dotenv import load_dotenv
-from langchain_community.document_loaders.blob_loaders import Blob
-from loguru import logger as log
-
-from rag.generator import MODELS
-from rag.generator.prompt import Prompt
-from rag.message import Message
-from rag.model import Rag
-from rag.retriever.vector import Document
-
-
-def set_chat_users():
-    log.debug("Setting user and bot value")
-    ss = st.session_state
-    ss.user = "user"
-    ss.bot = "assistant"
-
-
-@st.cache_resource
-def load_rag():
-    log.debug("Loading Rag...")
-    st.session_state.rag = Rag()
-
-
-@st.cache_resource
-def set_client(client: str):
-    log.debug("Setting client...")
-    rag = st.session_state.rag
-    rag.set_client(client)
-
-
-@st.cache_data(show_spinner=False)
-def upload(files):
-    rag = st.session_state.rag
-    with st.spinner("Uploading documents..."):
-        for file in files:
-            source = file.name
-            blob = Blob.from_data(file.read())
-            rag.retriever.add_pdf(blob=blob, source=source)
-
-
-def display_context(documents: List[Document]):
-    with st.popover("See Context"):
-        for i, doc in enumerate(documents):
-            st.markdown(f"### Document {i}")
-            st.markdown(f"**Title: {doc.title}**")
-            st.markdown(doc.text)
-            st.markdown("---")
-
-
-def display_chat():
-    ss = st.session_state
-    for msg in ss.chat:
-        if isinstance(msg, list):
-            display_context(msg)
-        else:
-            st.chat_message(msg.role).write(msg.content)
-
-
-def generate_chat(query: str):
-    ss = st.session_state
-
-    with st.chat_message(ss.user):
-        st.write(query)
-
-    rag = ss.rag
-    documents = rag.retrieve(query)
-    prompt = Prompt(query, documents, ss.model)
-    with st.chat_message(ss.bot):
-        response = st.write_stream(rag.generate(prompt))
-
-    rag.add_message(rag.bot, response)
-
-    display_context(prompt.documents)
-    store_chat(prompt, response)
-
-
-def store_chat(prompt: Prompt, response: str):
-    log.debug("Storing chat")
-    ss = st.session_state
-    query = Message(ss.user, prompt.query, ss.model)
-    response = Message(ss.bot, response, ss.model)
-    ss.chat.append(query)
-    ss.chat.append(response)
-    ss.chat.append(prompt.documents)
-
-
-def sidebar():
-    with st.sidebar:
-        st.header("Grounding")
-        st.markdown(
-            (
-                "These files will be uploaded to the knowledge base and used "
-                "as groudning if they are relevant to the question."
-            )
-        )
-
-        files = st.file_uploader(
-            "Choose pdfs to add to the knowledge base",
-            type="pdf",
-            accept_multiple_files=True,
-        )
-
-        upload(files)
-
-        st.header("Model")
-        st.markdown(
-            "Select the model that will be used for reranking and generating the answer."
-        )
-        st.selectbox("Model", key="model", options=MODELS)
-        set_client(st.session_state.model)
-
-
-def page():
-    ss = st.session_state
-    if "chat" not in st.session_state:
-        ss.chat = []
-
-    display_chat()
-
-    query = st.chat_input("Enter query here")
-    if query:
-        generate_chat(query)
-
-
-if __name__ == "__main__":
-    load_dotenv()
-    st.title("Retrieval Augmented Generation")
-    set_chat_users()
-    load_rag()
-    sidebar()
-    page()
diff --git a/rag/workflows.py b/rag/workflows.py
new file mode 100644
index 0000000..58d1d78
--- /dev/null
+++ b/rag/workflows.py
@@ -0,0 +1,7 @@
+
+
+def ollama_workflow():
+    pass
+
+def cohere_workflow():
+    pass
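
The renames above give the retriever a uniform surface: Retriever.add_pdf becomes Retriever.index, which takes a FileType (FilePath or Blob) instead of optional keyword arguments, and Retriever.retrieve becomes Retriever.search, which takes a Query instead of a bare string. A minimal usage sketch of the new entry points — not part of the commit; the path and query text are placeholders, and the usual environment variables (ENCODER_MODEL, EMBEDDING_DIM, CHUNK_SIZE, ...) are assumed to be set:

    from pathlib import Path

    from rag.retriever.encoder import Query
    from rag.retriever.retriever import FilePath, Retriever

    retriever = Retriever()

    # FilePath selects the path-based branch of Retriever.index
    # via structural pattern matching on the FileType union.
    retriever.index(FilePath(Path("papers/example.pdf")))

    # search() takes a Query dataclass and returns the new Documents
    # wrapper from rag.retriever.vector rather than a plain list.
    documents = retriever.search(Query("What does the encoder preamble do?"))
    for doc in documents.documents:
        print(doc.title, doc.text[:80])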
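
Likewise, rerank_documents and rerank_messages collapse into a single Reranker.rerank because Documents and Messages now share an informal protocol: content() yields the strings to score, and rerank(rankings) reorders the collection in place. A sketch of reranking chat history through that shared interface — assuming the local Reranker can be constructed directly and RERANK_MODEL, RERANK_TOP_K, and RERANK_RELEVANCE_THRESHOLD are set; the message text is a placeholder:

    from rag.message import Message, Messages
    from rag.retriever.encoder import Query
    from rag.retriever.rerank.local import Reranker

    reranker = Reranker()
    history = Messages(
        [Message(role="user", content="How are embeddings stored?", client="local")]
    )

    # The same call that reorders Documents also reorders Messages: the
    # CrossEncoder scores history.content() against the query, and
    # Messages.rerank() keeps the entries scoring above the threshold.
    history = reranker.rerank(Query("vector storage"), history)
    print(len(history))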
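
Both Encoder.encode and Retriever.index dispatch on a union of small dataclasses with match/case; positional patterns such as Query(query) and FilePath(path) work because @dataclass generates __match_args__. A self-contained illustration of the idiom (hypothetical names, Python 3.10+):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Query:
        query: str


    @dataclass
    class Documents:
        documents: List[str]


    def encode(x: Query | Documents) -> str:
        # @dataclass fills in __match_args__, so the patterns below
        # capture each class's first field positionally.
        match x:
            case Query(query):
                return f"encode one query: {query!r}"
            case Documents(documents):
                return f"encode {len(documents)} document chunks"


    print(encode(Query("late chunking")))
    print(encode(Documents(["chunk a", "chunk b"])))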