summaryrefslogtreecommitdiff
path: root/rag/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/message.py')
-rw-r--r--rag/message.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/rag/message.py b/rag/message.py
index d628982..d207596 100644
--- a/rag/message.py
+++ b/rag/message.py
@@ -1,5 +1,7 @@
from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, List
+
+from loguru import logger as log
@dataclass
@@ -13,3 +15,24 @@ class Message:
return {"role": self.role, "message": self.content}
else:
return {"role": self.role, "content": self.content}
+
+
+@dataclass
+class Messages:
+ messages: List[Message]
+
+ def __add__(self, message: Message):
+ self.messages.append(message)
+
+ def __len__(self):
+ return len(self.messages)
+
+ def reset(self):
+ log.debug("Resetting messages...")
+ self.messages = []
+
+ def content(self) -> List[str]:
+ return [m.content for m in self.messages]
+
+ def rerank(self, rankings: List[int]):
+ self.messages = [self.messages[r] for r in rankings]