summaryrefslogtreecommitdiff
path: root/rag/retriever
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-29 00:53:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-29 00:53:39 +0200
commit716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch)
tree778da9011d21051006fc206ce0978f0fc114b77b /rag/retriever
parent2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff)
Wip memory
Diffstat (limited to 'rag/retriever')
-rw-r--r--rag/retriever/memory.py51
-rw-r--r--rag/retriever/rerank/local.py21
2 files changed, 72 insertions, 0 deletions
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
new file mode 100644
index 0000000..c4455ed
--- /dev/null
+++ b/rag/retriever/memory.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from typing import Dict, List
+
+
@dataclass
class Log:
    """One conversation turn: the user's message and the bot's reply."""

    # The annotations are quoted because ``Message`` is defined further down
    # in this module; an unquoted name would raise NameError when the class
    # body is executed at import time.
    user: "Message"
    bot: "Message"

    def get(self) -> tuple["Message", "Message"]:
        """Return the turn as a ``(user, bot)`` pair, user message first."""
        return (self.user, self.bot)
+
+
@dataclass
class Message:
    """A single chat message together with the role that produced it."""

    role: str
    message: str

    def as_dict(self, model: str = "local") -> Dict[str, str]:
        """Serialise the message into the dict layout used for ``model``.

        Args:
            model: Name of the target backend. ``"cohere"`` selects the
                layout with upper-case roles and a ``"message"`` key; every
                other value selects the ``"role"``/``"content"`` layout.
                The default lets callers that do not care about the backend
                call ``as_dict()`` without arguments.

        Returns:
            For ``"cohere"``: ``{"role": "USER" or "CHATBOT", "message": ...}``
            where any role other than ``"user"`` is mapped to ``"CHATBOT"``.
            Otherwise: ``{"role": self.role, "content": self.message}``.
        """
        if model == "cohere":
            role = "USER" if self.role == "user" else "CHATBOT"
            return {"role": role, "message": self.message}
        return {"role": self.role, "content": self.message}
+
+
class Memory:
    """In-memory store of the conversation as a list of ``Log`` turns."""

    def __init__(self, reranker) -> None:
        """Create an empty memory.

        Args:
            reranker: Stored on the instance as ``self.reranker``; no method
                of this class uses it yet.
        """
        self.history: List["Log"] = []
        self.reranker = reranker
        # Role labels attached to the messages created by ``add``.
        self.user = "user"
        self.bot = "assistant"

    def add(self, prompt: str, response: str) -> None:
        """Append one turn made of the user ``prompt`` and bot ``response``."""
        self.history.append(
            Log(
                user=Message(role=self.user, message=prompt),
                bot=Message(role=self.bot, message=response),
            )
        )

    def get(self, model: str = "local") -> List[Dict[str, str]]:
        """Return the whole history flattened into message dicts.

        Args:
            model: Forwarded to ``Message.as_dict`` to pick the dict layout.
                ``as_dict`` requires this argument, so it is always passed
                explicitly.

        Returns:
            One dict per message in insertion order, with the user message
            of each turn placed before the bot message.
        """
        return [
            message.as_dict(model)
            for turn in self.history
            for message in (turn.user, turn.bot)
        ]

    def reset(self) -> None:
        """Forget every stored turn."""
        self.history = []
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 75fedd8..8e94882 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,9 +1,11 @@
import os
+from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
from rag.generator.prompt import Prompt
+from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):
prompt.documents[r.get("corpus_id", 0)] for r in ranking
]
return prompt
+
+ def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
+ if history:
+ results = self.model.rank(
+ query=prompt.query,
+ documents=[m.bot.message for m in history],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ )
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
+ )
+ history = [history[r.get("corpus_id", 0)] for r in ranking]
+ return history