summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-23 22:08:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-23 22:08:03 +0200
commit75be0914f6bd2cdeda1539f83b38fcbc854d5cfa (patch)
tree8d518f43926a98ee2cbae262d152662d0c07e9f6
parent694a4ad0e5a9e4c7eb6d11fff5ae414292ef8169 (diff)
Add reranking as a separate step
-rw-r--r--rag/cli.py4
-rw-r--r--rag/generator/abstract.py4
-rw-r--r--rag/generator/cohere.py26
-rw-r--r--rag/ui.py6
4 files changed, 23 insertions, 17 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 690563e..932e2a9 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -37,13 +37,13 @@ def rag(generator: str, query: str, limit):
retriever = Retriever()
generator = get_generator(generator)
documents = retriever.retrieve(query, limit=limit)
- prompt = Prompt(query, documents)
+ prompt = generator.rerank(Prompt(query, documents))
print("Answer: ")
for chunk in generator.generate(prompt):
print(chunk, end="", flush=True)
print("\n\n")
- for i, doc in enumerate(documents):
+ for i, doc in enumerate(prompt.documents):
print(f"### Document {i}")
print(f"**Title: {doc.title}**")
print(doc.text)
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..439c1b5 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -16,3 +16,7 @@ class AbstractGenerator(type):
@abstractmethod
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
pass
+
+ @abstractmethod
+ def rerank(self, prompt: Prompt) -> Prompt:
+ return prompt
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 3499e3b..049aea4 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -13,23 +13,23 @@ class Cohere(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
- def rerank(self, prompt: Prompt):
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=3,
- )
- ranking = list(
- filter(lambda x: x.relevance_score > 0.5, response.results)
- )
- prompt.documents = [prompt.documents[r.index] for r in ranking]
+ def rerank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ top_n=3,
+ )
+ ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [prompt.documents[r.index] for r in ranking]
return prompt
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
- if prompt.documents:
- prompt = self.rerank(prompt)
query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
message=query,
diff --git a/rag/ui.py b/rag/ui.py
index fb02e5c..2fbf8de 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -99,11 +99,13 @@ def generate_chat(query: str):
documents = retriever.retrieve(query, limit=15)
prompt = Prompt(query, documents)
+ prompt = generator.rerank(prompt)
+
with st.chat_message(ss.bot):
response = st.write_stream(generator.generate(prompt))
- display_context(documents)
- store_chat(query, response, documents)
+ display_context(prompt.documents)
+ store_chat(query, response, prompt.documents)
def store_chat(query: str, response: str, documents: List[Document]):