diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-05-29 00:53:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-05-29 00:53:39 +0200 |
commit | 716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch) | |
tree | 778da9011d21051006fc206ce0978f0fc114b77b /rag/retriever/memory.py | |
parent | 2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff) |
Wip memory
Diffstat (limited to 'rag/retriever/memory.py')
-rw-r--r-- | rag/retriever/memory.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py new file mode 100644 index 0000000..c4455ed --- /dev/null +++ b/rag/retriever/memory.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class Log: + user: Message + bot: Message + + def get(): + return (user, bot) + + +@dataclass +class Message: + role: str + message: str + + def as_dict(self, model: str) -> Dict[str, str]: + if model == "cohere": + match self.role: + case "user": + role = "USER" + case _: + role = "CHATBOT" + + return {"role": role, "message": self.message} + else: + return {"role": self.role, "content": self.message} + + +class Memory: + def __init__(self, reranker) -> None: + self.history = [] + self.reranker = reranker + self.user = "user" + self.bot = "assistant" + + def add(self, prompt: str, response: str): + self.history.append( + Log( + user=Message(role=self.user, message=prompt), + bot=Message(role=self.bot, message=response), + ) + ) + + def get(self) -> List[Log]: + return [m.as_dict() for log in self.history for m in log.get()] + + def reset(self): + self.history = [] |