From 9e0cbcb4e7f1f3f95f304046d3190c6ebc4d3901 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Apr 2024 09:09:24 +0200 Subject: Reformat and fix typo --- rag/generator/prompt.py | 7 ++-- rag/retriever/rerank.py | 77 ---------------------------------------- rag/retriever/rerank/__init__.py | 15 ++++++++ rag/retriever/rerank/abstract.py | 17 +++++++++ rag/retriever/rerank/cohere.py | 28 +++++++++++++++ rag/retriever/rerank/local.py | 30 ++++++++++++++++ rag/ui.py | 5 ++- 7 files changed, 96 insertions(+), 83 deletions(-) delete mode 100644 rag/retriever/rerank.py create mode 100644 rag/retriever/rerank/__init__.py create mode 100644 rag/retriever/rerank/abstract.py create mode 100644 rag/retriever/rerank/cohere.py create mode 100644 rag/retriever/rerank/local.py diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py index 10afe7b..6523842 100644 --- a/rag/generator/prompt.py +++ b/rag/generator/prompt.py @@ -4,9 +4,10 @@ from typing import List from rag.retriever.vector import Document ANSWER_INSTRUCTION = ( - "Given the context information and not prior knowledge, answer the query." - "If the context is irrelevant to the query or empty, then do not attempt to answer " - "the query, just reply that you do not know based on the context provided.\n" + "Do not attempt to answer the query without relevant context, and do not use " + "prior knowledge or training data!\n" + "If the context does not contain the answer or is empty, only reply that you " + "cannot answer the query given the context." ) diff --git a/rag/retriever/rerank.py b/rag/retriever/rerank.py deleted file mode 100644 index 08a9a27..0000000 --- a/rag/retriever/rerank.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -from abc import abstractmethod -from typing import Type - -import cohere -from loguru import logger as log -from sentence_transformers import CrossEncoder - -from rag.generator.prompt import Prompt - - -class AbstractReranker(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - instance = super().__call__(*args, **kwargs) - cls._instances[cls] = instance - return cls._instances[cls] - - @abstractmethod - def rank(self, prompt: Prompt) -> Prompt: - return prompt - - -class Reranker(metaclass=AbstractReranker): - def __init__(self) -> None: - self.model = CrossEncoder(os.environ["RERANK_MODEL"]) - self.top_k = int(os.environ["RERANK_TOP_K"]) - - def rank(self, prompt: Prompt) -> Prompt: - if prompt.documents: - results = self.model.rank( - query=prompt.query, - documents=[d.text for d in prompt.documents], - return_documents=False, - top_k=self.top_k, - ) - ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results)) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" - ) - prompt.documents = [ - prompt.documents[r.get("corpus_id", 0)] for r in ranking - ] - return prompt - - -class CohereReranker(metaclass=AbstractReranker): - def __init__(self) -> None: - self.client = cohere.Client(os.environ["COHERE_API_KEY"]) - self.top_k = int(os.environ["RERANK_TOP_K"]) - - def rank(self, prompt: Prompt) -> Prompt: - if prompt.documents: - response = self.client.rerank( - model="rerank-english-v3.0", - query=prompt.query, - documents=[d.text for d in prompt.documents], - top_n=self.top_k, - ) - ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results)) - log.debug( - f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" - ) - prompt.documents = [prompt.documents[r.index] for r in ranking] - return prompt - - -def get_reranker(model: str) -> Type[AbstractReranker]: - match model: - case "local": - return Reranker() - case "cohere": - return CohereReranker() - case _: - exit(1) diff --git a/rag/retriever/rerank/__init__.py b/rag/retriever/rerank/__init__.py new file mode 100644 index 0000000..16b2fac --- /dev/null +++ b/rag/retriever/rerank/__init__.py @@ -0,0 +1,15 @@ +from typing import Type + +from rag.retriever.rerank.abstract import AbstractReranker +from rag.retriever.rerank.cohere import CohereReranker +from rag.retriever.rerank.local import Reranker + + +def get_reranker(model: str) -> Type[AbstractReranker]: + match model: + case "local": + return Reranker() + case "cohere": + return CohereReranker() + case _: + exit(1) diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py new file mode 100644 index 0000000..b96b70a --- /dev/null +++ b/rag/retriever/rerank/abstract.py @@ -0,0 +1,17 @@ +from abc import abstractmethod + +from rag.generator.prompt import Prompt + + +class AbstractReranker(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + @abstractmethod + def rank(self, prompt: Prompt) -> Prompt: + return prompt diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py new file mode 100644 index 0000000..dac9ab5 --- /dev/null +++ b/rag/retriever/rerank/cohere.py @@ -0,0 +1,28 @@ +import os + +import cohere +from loguru import logger as log + +from rag.generator.prompt import Prompt +from rag.retriever.rerank.abstract import AbstractReranker + + +class CohereReranker(metaclass=AbstractReranker): + def __init__(self) -> None: + self.client = cohere.Client(os.environ["COHERE_API_KEY"]) + self.top_k = int(os.environ["RERANK_TOP_K"]) + + def rank(self, prompt: Prompt) -> Prompt: + if prompt.documents: + response = self.client.rerank( + model="rerank-english-v3.0", + query=prompt.query, + documents=[d.text for d in prompt.documents], + top_n=self.top_k, + ) + ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results)) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" + ) + prompt.documents = [prompt.documents[r.index] for r in ranking] + return prompt diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py new file mode 100644 index 0000000..758c5dc --- /dev/null +++ b/rag/retriever/rerank/local.py @@ -0,0 +1,30 @@ +import os + +from loguru import logger as log +from sentence_transformers import CrossEncoder + +from rag.generator.prompt import Prompt +from rag.retriever.rerank.abstract import AbstractReranker + + +class Reranker(metaclass=AbstractReranker): + def __init__(self) -> None: + self.model = CrossEncoder(os.environ["RERANK_MODEL"]) + self.top_k = int(os.environ["RERANK_TOP_K"]) + + def rank(self, prompt: Prompt) -> Prompt: + if prompt.documents: + results = self.model.rank( + query=prompt.query, + documents=[d.text for d in prompt.documents], + return_documents=False, + top_k=self.top_k, + ) + ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results)) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" + ) + prompt.documents = [ + prompt.documents[r.get("corpus_id", 0)] for r in ranking + ] + return prompt diff --git a/rag/ui.py b/rag/ui.py index f46c24d..36e8c4c 100644 --- a/rag/ui.py +++ b/rag/ui.py @@ -42,7 +42,6 @@ def load_retriever(): def load_generator(model: str): log.debug("Loading generator model") st.session_state.generator = get_generator(model) - set_chat_users() # @st.cache_resource @@ -113,7 +112,7 @@ def store_chat(query: str, response: str, documents: List[Document]): def sidebar(): with st.sidebar: - st.header("Grouding") + st.header("Grounding") st.markdown( ( "These files will be uploaded to the knowledge base and used " @@ -129,7 +128,7 @@ def sidebar(): upload(files) - st.header("Generative Model") + st.header("Model") st.markdown( "Select the model that will be used for reranking and generating the answer." ) -- cgit v1.2.3-70-g09d2