summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 01:51:06 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 01:51:06 +0200
commitf76660d052a79905748163b96c2cca8671ee8a24 (patch)
treeb615edc2ee6ab7a91d188c62fc05b96d08e48292
parent95305f59df84caded50286b1a57b6075e48725a8 (diff)
Remove rerank from generator
-rw-r--r--rag/generator/cohere.py15
1 files changed, 0 insertions, 15 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 049aea4..28a87e7 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -13,21 +13,6 @@ class Cohere(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
- def rerank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=3,
- )
- ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
- )
- prompt.documents = [prompt.documents[r.index] for r in ranking]
- return prompt
-
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"