From aac821b148c6c0d35b940609dc9b0ddcb053b28e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 19 Jun 2024 02:07:06 +0200 Subject: Still wip on rewrite --- rag/model.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 rag/model.py (limited to 'rag/model.py') diff --git a/rag/model.py b/rag/model.py new file mode 100644 index 0000000..b186d43 --- /dev/null +++ b/rag/model.py @@ -0,0 +1,57 @@ +from typing import Any, Generator, List + +from loguru import logger as log + +from rag.generator import get_generator +from rag.generator.prompt import Prompt +from rag.message import Message +from rag.retriever.rerank import get_reranker +from rag.retriever.retriever import Retriever +from rag.retriever.vector import Document + + +class Rag: + def __init__(self, client: str = "local") -> None: + self.bot = None + self.user = None + self.client = client + self.messages: List[Message] = [] + self.retriever = Retriever() + self.reranker = get_reranker(self.client) + self.generator = get_generator(self.client) + self.__set_roles() + + def __set_roles(self): + self.bot = "assistant" if self.client == "local" else "CHATBOT" + self.user = "user" if self.client == "local" else "USER" + + def set_client(self, client: str): + log.info(f"Setting client to {client}") + self.client = client + self.reranker = get_reranker(self.client) + self.generator = get_generator(self.client) + self.__set_roles() + self.__reset_messages() + log.debug(f"Swapped client to {self.client}") + + def __reset_messages(self): + log.debug("Deleting messages...") + self.messages = [] + + def retrieve(self, query: str) -> List[Document]: + documents = self.retriever.retrieve(query) + log.info(f"Found {len(documents)} relevant documents") + return self.reranker.rerank_documents(query, documents) + + def add_message(self, role: str, content: str): + self.messages.append(Message(role=role, content=content, client=self.client)) + + def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: + if self.messages: + messages = self.reranker.rerank_messages(prompt.query, self.messages) + else: + messages = [] + messages.append( + Message(role=self.user, content=prompt.to_str(), client=self.client) + ) + return self.generator.generate(messages, prompt.documents) -- cgit v1.2.3-70-g09d2