diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-06-19 02:07:06 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-06-19 02:07:06 +0200 |
commit | aac821b148c6c0d35b940609dc9b0ddcb053b28e (patch) | |
tree | 5c125045b2b60ead39e093327d664adf43d1d35b /rag/rag.py | |
parent | f2846429310452bebbf0d07203b1e53978c439c7 (diff) |
Still wip on rewrite
Diffstat (limited to 'rag/rag.py')
-rw-r--r-- | rag/rag.py | 69 |
1 files changed, 0 insertions, 69 deletions
diff --git a/rag/rag.py b/rag/rag.py deleted file mode 100644 index 1f6a176..0000000 --- a/rag/rag.py +++ /dev/null @@ -1,69 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Generator, List - -from loguru import logger as log - -from rag.generator import get_generator -from rag.generator.prompt import Prompt -from rag.retriever.rerank import get_reranker -from rag.retriever.retriever import Retriever -from rag.retriever.vector import Document - - -@dataclass -class Message: - role: str - content: str - client: str - - def as_dict(self) -> Dict[str, str]: - if self.client == "cohere": - return {"role": self.role, "message": self.content} - else: - return {"role": self.role, "content": self.content} - - -class Rag: - def __init__(self, client: str) -> None: - self.messages: List[Message] = [] - self.retriever = Retriever() - self.client = client - self.reranker = get_reranker(self.client) - self.generator = get_generator(self.client) - self.bot = "assistant" if self.client == "ollama" else "CHATBOT" - self.user = "user" if self.client == "ollama" else "USER" - - def __set_roles(self): - self.bot = "assistant" if self.client == "ollama" else "CHATBOT" - self.user = "user" if self.client == "ollama" else "USER" - - def set_client(self, client: str): - self.client = client - self.reranker = get_reranker(self.client) - self.generator = get_generator(self.client) - self.__set_roles() - self.__reset_messages() - log.debug(f"Swapped client to {self.client}") - - def __reset_messages(self): - log.debug("Deleting messages...") - self.messages = [] - - def retrieve(self, query: str) -> List[Document]: - documents = self.retriever.retrieve(query) - log.info(f"Found {len(documents)} relevant documents") - return self.reranker.rerank_documents(query, documents) - - def add_message(self, role: str, content: str): - self.messages.append( - Message(role=role, content=content, client=self.client) - ) - - def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]: - messages = self.reranker.rerank_messages(prompt.query, self.messages) - self.messages.append( - Message( - role=self.user, content=prompt.to_str(), client=self.client - ) - ) - return self.generator.generate(prompt, messages) |