Diffstat (limited to 'rag/retriever')
-rw-r--r--   rag/retriever/memory.py          |  51
-rw-r--r--   rag/retriever/rerank/abstract.py |  12
-rw-r--r--   rag/retriever/rerank/cohere.py   |  55
-rw-r--r--   rag/retriever/rerank/local.py    |  70
4 files changed, 76 insertions, 112 deletions
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
deleted file mode 100644
index c4455ed..0000000
--- a/rag/retriever/memory.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List
-
-
-@dataclass
-class Log:
-    user: Message
-    bot: Message
-
-    def get():
-        return (user, bot)
-
-
-@dataclass
-class Message:
-    role: str
-    message: str
-
-    def as_dict(self, model: str) -> Dict[str, str]:
-        if model == "cohere":
-            match self.role:
-                case "user":
-                    role = "USER"
-                case _:
-                    role = "CHATBOT"
-
-            return {"role": role, "message": self.message}
-        else:
-            return {"role": self.role, "content": self.message}
-
-
-class Memory:
-    def __init__(self, reranker) -> None:
-        self.history = []
-        self.reranker = reranker
-        self.user = "user"
-        self.bot = "assistant"
-
-    def add(self, prompt: str, response: str):
-        self.history.append(
-            Log(
-                user=Message(role=self.user, message=prompt),
-                bot=Message(role=self.bot, message=response),
-            )
-        )
-
-    def get(self) -> List[Log]:
-        return [m.as_dict() for log in self.history for m in log.get()]
-
-    def reset(self):
-        self.history = []
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index b96b70a..f32ee77 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,6 +1,8 @@
 from abc import abstractmethod
+from typing import List
 
-from rag.generator.prompt import Prompt
+from rag.memory import Message
+from rag.retriever.vector import Document
 
 
 class AbstractReranker(type):
@@ -13,5 +15,9 @@ class AbstractReranker(type):
         return cls._instances[cls]
 
     @abstractmethod
-    def rank(self, prompt: Prompt) -> Prompt:
-        return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        pass
+
+    @abstractmethod
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        pass
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 43690a1..33c373d 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -1,10 +1,12 @@
 import os
+from typing import List
 
 import cohere
 from loguru import logger as log
 
-from rag.generator.prompt import Prompt
+from rag.rag import Message
 from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
 
 
 class CohereReranker(metaclass=AbstractReranker):
@@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker):
         self.top_k = int(os.environ["RERANK_TOP_K"])
         self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
 
-    def rank(self, prompt: Prompt) -> Prompt:
-        if prompt.documents:
-            response = self.client.rerank(
-                model="rerank-english-v3.0",
-                query=prompt.query,
-                documents=[d.text for d in prompt.documents],
-                top_n=self.top_k,
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=query,
+            documents=[d.text for d in documents],
+            top_n=self.top_k,
+        )
+        ranking = list(
+            filter(
+                lambda x: x.relevance_score > self.relevance_threshold,
+                response.results,
             )
-            ranking = list(
-                filter(
-                    lambda x: x.relevance_score > self.relevance_threshold,
-                    response.results,
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+        )
+        return [documents[r.index] for r in ranking]
+
+
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=query,
+            documents=[m.message for m in messages],
+            top_n=self.top_k,
+        )
+        ranking = list(
+            filter(
+                lambda x: x.relevance_score > self.relevance_threshold,
+                response.results,
            )
-            prompt.documents = [prompt.documents[r.index] for r in ranking]
-        return prompt
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+        )
+        return [messages[r.index] for r in ranking]
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 8e94882..e727165 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,10 +2,10 @@ import os
 from typing import List
 
 from loguru import logger as log
+from rag.rag import Message
+from rag.retriever.vector import Document
 from sentence_transformers import CrossEncoder
 
-from rag.generator.prompt import Prompt
-from rag.retriever.memory import Log
 from rag.retriever.rerank.abstract import AbstractReranker
 
 
@@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker):
         self.top_k = int(os.environ["RERANK_TOP_K"])
         self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
 
-    def rank(self, prompt: Prompt) -> Prompt:
-        if prompt.documents:
-            results = self.model.rank(
-                query=prompt.query,
-                documents=[d.text for d in prompt.documents],
-                return_documents=False,
-                top_k=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
-            )
-            prompt.documents = [
-                prompt.documents[r.get("corpus_id", 0)] for r in ranking
-            ]
-        return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+        results = self.model.rank(
+            query=query,
+            documents=[d.text for d in documents],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+        )
+        return [documents[r.get("corpus_id", 0)] for r in ranking]
 
-    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
-        if history:
-            results = self.model.rank(
-                query=prompt.query,
-                documents=[m.bot.message for m in history],
-                return_documents=False,
-                top_k=self.top_k,
-            )
-            ranking = list(
-                filter(
-                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
-                )
-            )
-            log.debug(
-                f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
-            )
-            history = [history[r.get("corpus_id", 0)] for r in ranking]
-        return history
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        results = self.model.rank(
+            query=query,
+            documents=[m.message for m in messages],
+            return_documents=False,
+            top_k=self.top_k,
+        )
+        ranking = list(
+            filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+        )
+        log.debug(
+            f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+        )
+        return [messages[r.get("corpus_id", 0)] for r in ranking]
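The refactor replaces the old Prompt-mutating rank() with two methods, rerank_documents() and rerank_messages(), that take the query and a list and return only the relevant subset; both the Cohere and the local CrossEncoder backends implement them. The snippet below is a minimal usage sketch and not part of this commit: it assumes Reranker() needs no constructor arguments beyond the two environment variables read in __init__, that the CrossEncoder model configured there can be loaded, and it uses SimpleNamespace stand-ins for Document and Message, since the rerankers only touch the .text and .message attributes shown in the diff.

import os
from types import SimpleNamespace

# Example values; the rerankers read these in __init__.
os.environ.setdefault("RERANK_TOP_K", "5")
os.environ.setdefault("RETRIEVER_RELEVANCE_THRESHOLD", "0.5")

from rag.retriever.rerank.local import Reranker

reranker = Reranker()  # singleton, enforced by the AbstractReranker metaclass

query = "how do I reset the chat memory?"
# Hypothetical stand-ins: only .text and .message are accessed by the rerankers.
documents = [SimpleNamespace(text="Memory.reset() clears the history list.")]
messages = [SimpleNamespace(role="assistant", message="Call reset() to clear it.")]

# The old rank(prompt) mutated a Prompt in place; the new interface returns
# the relevant documents and chat messages as separate lists.
relevant_docs = reranker.rerank_documents(query, documents)
relevant_msgs = reranker.rerank_messages(query, messages)

Returning filtered lists instead of mutating a Prompt keeps the rerankers independent of the generator layer, which is why the rag.generator.prompt import disappears from all three reranker files.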