diff options
Diffstat (limited to 'rag')
-rw-r--r-- | rag/cli.py | 17 | ||||
-rw-r--r-- | rag/generator/__init__.py | 6 | ||||
-rw-r--r-- | rag/generator/abstract.py | 11 | ||||
-rw-r--r-- | rag/generator/cohere.py | 12 | ||||
-rw-r--r-- | rag/generator/ollama.py | 11 | ||||
-rw-r--r-- | rag/generator/prompt.py | 10 | ||||
-rw-r--r-- | rag/message.py | 15 | ||||
-rw-r--r-- | rag/model.py (renamed from rag/rag.py) | 48 | ||||
-rw-r--r-- | rag/retriever/rerank/abstract.py | 2 | ||||
-rw-r--r-- | rag/retriever/rerank/cohere.py | 10 | ||||
-rw-r--r-- | rag/retriever/rerank/local.py | 2 | ||||
-rw-r--r-- | rag/ui.py | 52 |
12 files changed, 96 insertions, 100 deletions
@@ -7,6 +7,7 @@ from tqdm import tqdm from rag.generator import get_generator from rag.generator.prompt import Prompt +from rag.model import Rag from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever @@ -57,22 +58,20 @@ def upload(directory: str, verbose: int): prompt="Enter your query", ) @click.option( - "-m", - "--model", + "-c", + "--client", type=click.Choice(["local", "cohere"], case_sensitive=False), default="local", help="Generator and rerank model", ) @click.option("-v", "--verbose", count=True) -def rag(query: str, model: str, verbose: int): +def rag(query: str, client: str, verbose: int): configure_logging(verbose) - retriever = Retriever() - generator = get_generator(model) - reranker = get_reranker(model) - documents = retriever.retrieve(query) - prompt = reranker.rank(Prompt(query, documents)) + rag = Rag(client) + documents = rag.retrieve(query) + prompt = Prompt(query, documents, client) print("Answer: ") - for chunk in generator.generate(prompt): + for chunk in rag.generate(prompt): print(chunk, end="", flush=True) print("\n\n") diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py index a776231..770db16 100644 --- a/rag/generator/__init__.py +++ b/rag/generator/__init__.py @@ -1,8 +1,8 @@ from typing import Type -from .abstract import AbstractGenerator -from .cohere import Cohere -from .ollama import Ollama +from rag.generator.abstract import AbstractGenerator +from rag.generator.cohere import Cohere +from rag.generator.ollama import Ollama MODELS = ["local", "cohere"] diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py index 995e937..3ce997e 100644 --- a/rag/generator/abstract.py +++ b/rag/generator/abstract.py @@ -1,9 +1,8 @@ from abc import abstractmethod -from typing import Any, Generator +from typing import Any, Generator, List -from rag.rag import Message - -from .prompt import Prompt +from rag.message import Message +from rag.retriever.vector import Document class AbstractGenerator(type): @@ -16,7 +15,5 @@ class AbstractGenerator(type): return cls._instances[cls] @abstractmethod - def generate( - self, prompt: Prompt, messages: List[Message] - ) -> Generator[Any, Any, Any]: + def generate(self, messages: List[Message], documents: List[Document]) -> Generator[Any, Any, Any]: pass diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index f30fe69..575452f 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -5,10 +5,10 @@ from typing import Any, Generator, List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document from .abstract import AbstractGenerator -from .prompt import Prompt class Cohere(metaclass=AbstractGenerator): @@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator): self.client = cohere.Client(os.environ["COHERE_API_KEY"]) def generate( - self, prompt: Prompt, messages: List[Message] + self, messages: List[Message], documents: List[Document] ) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") for event in self.client.chat_stream( - message=prompt.to_str(), - documents=[asdict(d) for d in prompt.documents], - chat_history=[m.as_dict() for m in messages], + message=messages[-1].content, + documents=[asdict(d) for d in documents], + chat_history=[m.as_dict() for m in messages[:-1]], prompt_truncation="AUTO", ): if event.event_type == "text-generation": diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py index ff5402b..84563bb 100644 --- a/rag/generator/ollama.py +++ b/rag/generator/ollama.py @@ -4,10 +4,10 @@ from typing import Any, Generator, List import ollama from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document from .abstract import AbstractGenerator -from .prompt import Prompt class Ollama(metaclass=AbstractGenerator): @@ -16,12 +16,9 @@ class Ollama(metaclass=AbstractGenerator): log.debug(f"Using {self.model} for generator...") def generate( - self, prompt: Prompt, messages: List[Message] + self, messages: List[Message], documents: List[Document] ) -> Generator[Any, Any, Any]: log.debug("Generating answer with ollama...") - messages = messages.append( - Message(role="user", content=prompt.to_str(), client="ollama") - ) messages = [m.as_dict() for m in messages] for chunk in ollama.chat(model=self.model, messages=messages, stream=True): - yield chunk["response"] + yield chunk["message"]["content"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py index 4840fdc..cedf610 100644 --- a/rag/generator/prompt.py +++ b/rag/generator/prompt.py @@ -15,7 +15,7 @@ ANSWER_INSTRUCTION = ( class Prompt: query: str documents: List[Document] - generator_model: str + client: str def __context(self, documents: List[Document]) -> str: results = [ @@ -25,17 +25,15 @@ class Prompt: return "\n".join(results) def to_str(self) -> str: - if self.generator_model == "cohere": + if self.client == "cohere": return f"{self.query}\n\n{ANSWER_INSTRUCTION}" else: return ( "Context information is below.\n" - "---------------------\n" + "---\n" f"{self.__context(self.documents)}\n\n" - "---------------------\n" + "---\n" f"{ANSWER_INSTRUCTION}" - "Do not attempt to answer the query without relevant context and do not use" - " prior knowledge or training data!\n" f"Query: {self.query.strip()}\n\n" "Answer:" ) diff --git a/rag/message.py b/rag/message.py new file mode 100644 index 0000000..d628982 --- /dev/null +++ b/rag/message.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Dict + + +@dataclass +class Message: + role: str + content: str + client: str + + def as_dict(self) -> Dict[str, str]: + if self.client == "cohere": + return {"role": self.role, "message": self.content} + else: + return {"role": self.role, "content": self.content} diff --git a/rag/rag.py b/rag/model.py index 1f6a176..b186d43 100644 --- a/rag/rag.py +++ b/rag/model.py @@ -1,43 +1,32 @@ -from dataclasses import dataclass -from typing import Any, Dict, Generator, List +from typing import Any, Generator, List from loguru import logger as log from rag.generator import get_generator from rag.generator.prompt import Prompt +from rag.message import Message from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever from rag.retriever.vector import Document -@dataclass -class Message: - role: str - content: str - client: str - - def as_dict(self) -> Dict[str, str]: - if self.client == "cohere": - return {"role": self.role, "message": self.content} - else: - return {"role": self.role, "content": self.content} - - class Rag: - def __init__(self, client: str) -> None: + def __init__(self, client: str = "local") -> None: + self.bot = None + self.user = None + self.client = client self.messages: List[Message] = [] self.retriever = Retriever() - self.client = client self.reranker = get_reranker(self.client) self.generator = get_generator(self.client) - self.bot = "assistant" if self.client == "ollama" else "CHATBOT" - self.user = "user" if self.client == "ollama" else "USER" + self.__set_roles() def __set_roles(self): - self.bot = "assistant" if self.client == "ollama" else "CHATBOT" - self.user = "user" if self.client == "ollama" else "USER" + self.bot = "assistant" if self.client == "local" else "CHATBOT" + self.user = "user" if self.client == "local" else "USER" def set_client(self, client: str): + log.info(f"Setting client to {client}") self.client = client self.reranker = get_reranker(self.client) self.generator = get_generator(self.client) @@ -55,15 +44,14 @@ class Rag: return self.reranker.rerank_documents(query, documents) def add_message(self, role: str, content: str): - self.messages.append( - Message(role=role, content=content, client=self.client) - ) + self.messages.append(Message(role=role, content=content, client=self.client)) def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: - messages = self.reranker.rerank_messages(prompt.query, self.messages) - self.messages.append( - Message( - role=self.user, content=prompt.to_str(), client=self.client - ) + if self.messages: + messages = self.reranker.rerank_messages(prompt.query, self.messages) + else: + messages = [] + messages.append( + Message(role=self.user, content=prompt.to_str(), client=self.client) ) - return self.generator.generate(prompt, messages) + return self.generator.generate(messages, prompt.documents) diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py index f32ee77..015a60d 100644 --- a/rag/retriever/rerank/abstract.py +++ b/rag/retriever/rerank/abstract.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import List -from rag.memory import Message +from rag.message import Message from rag.retriever.vector import Document diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py index 33c373d..52f31a8 100644 --- a/rag/retriever/rerank/cohere.py +++ b/rag/retriever/rerank/cohere.py @@ -4,7 +4,7 @@ from typing import List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message from rag.retriever.rerank.abstract import AbstractReranker from rag.retriever.vector import Document @@ -35,11 +35,11 @@ class CohereReranker(metaclass=AbstractReranker): def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: - response = self.model.rank( + response = self.client.rerank( + model="rerank-english-v3.0", query=query, - documents=[m.message for m in messages], - return_documents=False, - top_k=self.top_k, + documents=[m.content for m in messages], + top_n=self.top_k, ) ranking = list( filter( diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index e727165..fd42c2c 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,7 +2,7 @@ import os from typing import List from loguru import logger as log -from rag.rag import Message +from rag.message import Message from rag.retriever.vector import Document from sentence_transformers import CrossEncoder @@ -1,15 +1,14 @@ -from dataclasses import dataclass -from typing import Dict, List +from typing import List import streamlit as st from dotenv import load_dotenv from langchain_community.document_loaders.blob_loaders import Blob from loguru import logger as log -from rag.generator import MODELS, get_generator +from rag.generator import MODELS from rag.generator.prompt import Prompt -from rag.retriever.rerank import get_reranker -from rag.retriever.retriever import Retriever +from rag.message import Message +from rag.model import Rag from rag.retriever.vector import Document @@ -19,25 +18,27 @@ def set_chat_users(): ss.user = "user" ss.bot = "assistant" +@st.cache_resource +def load_rag(): + log.debug("Loading Rag...") + st.session_state.rag = Rag() -def load_generator(model: str): - log.debug("Loading rag") - st.session_state.rag = get_generator(model) - -def load_reranker(model: str): - log.debug("Loading reranker model") - st.session_state.reranker = get_reranker(model) +@st.cache_resource +def set_client(client: str): + log.debug("Setting client...") + rag = st.session_state.rag + rag.set_client(client) @st.cache_data(show_spinner=False) def upload(files): - retriever = st.session_state.retriever + rag = st.session_state.rag with st.spinner("Uploading documents..."): for file in files: source = file.name blob = Blob.from_data(file.read()) - retriever.add_pdf(blob=blob, source=source) + rag.retriever.add_pdf(blob=blob, source=source) def display_context(documents: List[Document]): @@ -55,7 +56,7 @@ def display_chat(): if isinstance(msg, list): display_context(msg) else: - st.chat_message(msg.role).write(msg.message) + st.chat_message(msg.role).write(msg.content) def generate_chat(query: str): @@ -66,22 +67,24 @@ def generate_chat(query: str): rag = ss.rag documents = rag.retrieve(query) - Prompt(query, documents, self.client) + prompt = Prompt(query, documents, ss.model) with st.chat_message(ss.bot): - response = st.write_stream(rag.generate(query)) + response = st.write_stream(rag.generate(prompt)) + + rag.add_message(rag.bot, response) display_context(prompt.documents) - store_chat(query, response, prompt.documents) + store_chat(prompt, response) -def store_chat(query: str, response: str, documents: List[Document]): +def store_chat(prompt: Prompt, response: str): log.debug("Storing chat") ss = st.session_state - query = Message(role=ss.user, message=query) - response = Message(role=ss.bot, message=response) + query = Message(ss.user, prompt.query, ss.model) + response = Message(ss.bot, response, ss.model) ss.chat.append(query) ss.chat.append(response) - ss.chat.append(documents) + ss.chat.append(prompt.documents) def sidebar(): @@ -107,8 +110,7 @@ def sidebar(): "Select the model that will be used for reranking and generating the answer." ) st.selectbox("Model", key="model", options=MODELS) - load_generator(st.session_state.model) - load_reranker(st.session_state.model) + set_client(st.session_state.model) def page(): @@ -127,6 +129,6 @@ if __name__ == "__main__": load_dotenv() st.title("Retrieval Augmented Generation") set_chat_users() - load_retriever() + load_rag() sidebar() page() |