summaryrefslogtreecommitdiff
path: root/rag/retriever
diff options
context:
space:
mode:
Diffstat (limited to 'rag/retriever')
-rw-r--r--rag/retriever/memory.py51
-rw-r--r--rag/retriever/rerank/abstract.py12
-rw-r--r--rag/retriever/rerank/cohere.py55
-rw-r--r--rag/retriever/rerank/local.py70
4 files changed, 76 insertions, 112 deletions
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
deleted file mode 100644
index c4455ed..0000000
--- a/rag/retriever/memory.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List
-
-
-@dataclass
-class Log:
- user: Message
- bot: Message
-
- def get():
- return (user, bot)
-
-
-@dataclass
-class Message:
- role: str
- message: str
-
- def as_dict(self, model: str) -> Dict[str, str]:
- if model == "cohere":
- match self.role:
- case "user":
- role = "USER"
- case _:
- role = "CHATBOT"
-
- return {"role": role, "message": self.message}
- else:
- return {"role": self.role, "content": self.message}
-
-
-class Memory:
- def __init__(self, reranker) -> None:
- self.history = []
- self.reranker = reranker
- self.user = "user"
- self.bot = "assistant"
-
- def add(self, prompt: str, response: str):
- self.history.append(
- Log(
- user=Message(role=self.user, message=prompt),
- bot=Message(role=self.bot, message=response),
- )
- )
-
- def get(self) -> List[Log]:
- return [m.as_dict() for log in self.history for m in log.get()]
-
- def reset(self):
- self.history = []
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index b96b70a..f32ee77 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,6 +1,8 @@
from abc import abstractmethod
+from typing import List
-from rag.generator.prompt import Prompt
+from rag.rag import Message
+from rag.retriever.vector import Document
class AbstractReranker(type):
@@ -13,5 +15,9 @@ class AbstractReranker(type):
return cls._instances[cls]
@abstractmethod
- def rank(self, prompt: Prompt) -> Prompt:
- return prompt
+ def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+ pass
+
+ @abstractmethod
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ pass
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 43690a1..33c373d 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -1,10 +1,12 @@
import os
+from typing import List
import cohere
from loguru import logger as log
-from rag.generator.prompt import Prompt
+from rag.rag import Message
from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
class CohereReranker(metaclass=AbstractReranker):
@@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker):
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=self.top_k,
+ def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=query,
+ documents=[d.text for d in documents],
+ top_n=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
)
- ranking = list(
- filter(
- lambda x: x.relevance_score > self.relevance_threshold,
- response.results,
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+ )
+ return [documents[r.index] for r in ranking]
+
+
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=query,
+ documents=[m.message for m in messages],
+ top_n=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
)
- prompt.documents = [prompt.documents[r.index] for r in ranking]
- return prompt
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ )
+ return [messages[r.index] for r in ranking]
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 8e94882..e727165 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,10 +2,10 @@ import os
from typing import List
from loguru import logger as log
+from rag.rag import Message
+from rag.retriever.vector import Document
from sentence_transformers import CrossEncoder
-from rag.generator.prompt import Prompt
-from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker):
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- results = self.model.rank(
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(
- lambda x: x.get("score", 0.0) > self.relevance_threshold, results
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
- )
- prompt.documents = [
- prompt.documents[r.get("corpus_id", 0)] for r in ranking
- ]
- return prompt
+ def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+ results = self.model.rank(
+ query=query,
+ documents=[d.text for d in documents],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+ )
+ return [documents[r.get("corpus_id", 0)] for r in ranking]
- def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
- if history:
- results = self.model.rank(
- query=prompt.query,
- documents=[m.bot.message for m in history],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(
- lambda x: x.get("score", 0.0) > self.relevance_threshold, results
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
- )
- history = [history[r.get("corpus_id", 0)] for r in ranking]
- return history
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ results = self.model.rank(
+ query=query,
+ documents=[m.message for m in messages],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ )
+ return [messages[r.get("corpus_id", 0)] for r in ranking]