summaryrefslogtreecommitdiff
path: root/rag/retriever/memory.py
blob: c4455ed22d464e4675f41e803c0c35b2e92bc8c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Log:
    """One conversational exchange: a user message and the bot's reply."""

    # Quoted on purpose: Message is defined further down in this module, and
    # an unquoted annotation would be looked up while this class body runs.
    user: "Message"
    bot: "Message"

    def get(self) -> tuple["Message", "Message"]:
        """Return the exchange as a ``(user, bot)`` pair, in that order."""
        return (self.user, self.bot)


@dataclass
class Message:
    """A single chat message together with the role of its author."""

    role: str
    message: str

    def as_dict(self, model: str) -> Dict[str, str]:
        """Serialize the message into the dict layout used for ``model``.

        Args:
            model: Name of the target model. Only ``"cohere"`` gets special
                treatment; every other value yields the generic layout.

        Returns:
            For ``"cohere"``: ``{"role": ..., "message": ...}`` where the role
            is ``"USER"`` when ``self.role`` is ``"user"`` and ``"CHATBOT"``
            otherwise. For anything else: ``{"role": ..., "content": ...}``
            with the role passed through unchanged.
        """
        if model != "cohere":
            return {"role": self.role, "content": self.message}

        # Every role other than "user" is mapped to "CHATBOT".
        cohere_role = "USER" if self.role == "user" else "CHATBOT"
        return {"role": cohere_role, "message": self.message}


class Memory:
    """Accumulates the chat history as a list of prompt/response logs."""

    def __init__(self, reranker) -> None:
        """Create an empty history.

        Args:
            reranker: Stored on the instance unchanged; none of the methods
                defined on this class use it.
        """
        # Each entry is one Log holding a user message and a bot message.
        self.history = []
        self.reranker = reranker
        # Role labels stamped on the messages created by ``add``.
        self.user = "user"
        self.bot = "assistant"

    def add(self, prompt: str, response: str) -> None:
        """Append one exchange (user prompt plus bot response) to the history."""
        self.history.append(
            Log(
                user=Message(role=self.user, message=prompt),
                bot=Message(role=self.bot, message=response),
            )
        )

    def get(self, model: str = "") -> List[Dict[str, str]]:
        """Return the history flattened to one dict per message.

        Messages appear in insertion order, the user message of each exchange
        before the bot message.

        Args:
            model: Forwarded to every message's ``as_dict`` to select the
                dict layout. The default empty string is simply passed
                through like any other value.

        Returns:
            A list with two dicts per recorded exchange; empty when nothing
            has been added.
        """
        # Read the two fields directly so the order is explicit and no helper
        # on the log object is required.
        return [
            message.as_dict(model)
            for log in self.history
            for message in (log.user, log.bot)
        ]

    def reset(self) -> None:
        """Drop all recorded exchanges."""
        # Rebind rather than clear in place, so a list handed out earlier via
        # ``self.history`` is left untouched.
        self.history = []