From 75be0914f6bd2cdeda1539f83b38fcbc854d5cfa Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Tue, 23 Apr 2024 22:08:03 +0200
Subject: Add reranking as a separate step

---
 rag/cli.py                |  4 ++--
 rag/generator/abstract.py |  4 ++++
 rag/generator/cohere.py   | 26 +++++++++++++-------------
 rag/ui.py                 |  6 ++++--
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/rag/cli.py b/rag/cli.py
index 690563e..932e2a9 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -37,13 +37,13 @@ def rag(generator: str, query: str, limit):
     retriever = Retriever()
     generator = get_generator(generator)
     documents = retriever.retrieve(query, limit=limit)
-    prompt = Prompt(query, documents)
+    prompt = generator.rerank(Prompt(query, documents))
     print("Answer: ")
     for chunk in generator.generate(prompt):
         print(chunk, end="", flush=True)
 
     print("\n\n")
-    for i, doc in enumerate(documents):
+    for i, doc in enumerate(prompt.documents):
         print(f"### Document {i}")
         print(f"**Title: {doc.title}**")
         print(doc.text)
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..439c1b5 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -16,3 +16,7 @@ class AbstractGenerator(type):
     @abstractmethod
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
         pass
+
+    @abstractmethod
+    def rerank(self, prompt: Prompt) -> Prompt:
+        return prompt
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 3499e3b..049aea4 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -13,23 +13,23 @@ class Cohere(metaclass=AbstractGenerator):
     def __init__(self) -> None:
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
-    def rerank(self, prompt: Prompt):
-        response = self.client.rerank(
-            model="rerank-english-v3.0",
-            query=prompt.query,
-            documents=[d.text for d in prompt.documents],
-            top_n=3,
-        )
-        ranking = list(
-            filter(lambda x: x.relevance_score > 0.5, response.results)
-        )
-        prompt.documents = [prompt.documents[r.index] for r in ranking]
+    def rerank(self, prompt: Prompt) -> Prompt:
+        if prompt.documents:
+            response = self.client.rerank(
+                model="rerank-english-v3.0",
+                query=prompt.query,
+                documents=[d.text for d in prompt.documents],
+                top_n=3,
+            )
+            ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+            )
+            prompt.documents = [prompt.documents[r.index] for r in ranking]
         return prompt
 
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
         log.debug("Generating answer from cohere...")
-        if prompt.documents:
-            prompt = self.rerank(prompt)
         query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
             message=query,
diff --git a/rag/ui.py b/rag/ui.py
index fb02e5c..2fbf8de 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -99,11 +99,13 @@ def generate_chat(query: str):
 
     documents = retriever.retrieve(query, limit=15)
     prompt = Prompt(query, documents)
+    prompt = generator.rerank(prompt)
+
     with st.chat_message(ss.bot):
         response = st.write_stream(generator.generate(prompt))
 
-    display_context(documents)
-    store_chat(query, response, documents)
+    display_context(prompt.documents)
+    store_chat(query, response, prompt.documents)
 
 
 def store_chat(query: str, response: str, documents: List[Document]):
-- 
cgit v1.2.3-70-g09d2