summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-24 09:09:24 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-24 09:09:24 +0200
commit 9e0cbcb4e7f1f3f95f304046d3190c6ebc4d3901 (patch)
tree 5d890ce2705b79f23d63988c140d08edadaf35c5
parent 2e85325639ce3827cc2eb32f9750dfa873e3a480 (diff)
Reformat and fix typo
-rw-r--r-- rag/generator/prompt.py 7
-rw-r--r-- rag/retriever/rerank.py 77
-rw-r--r-- rag/retriever/rerank/__init__.py 15
-rw-r--r-- rag/retriever/rerank/abstract.py 17
-rw-r--r-- rag/retriever/rerank/cohere.py 28
-rw-r--r-- rag/retriever/rerank/local.py 30
-rw-r--r-- rag/ui.py 5
7 files changed, 96 insertions, 83 deletions
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 10afe7b..6523842 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -4,9 +4,10 @@ from typing import List
from rag.retriever.vector import Document
ANSWER_INSTRUCTION = (
- "Given the context information and not prior knowledge, answer the query."
- "If the context is irrelevant to the query or empty, then do not attempt to answer "
- "the query, just reply that you do not know based on the context provided.\n"
+ "Do not attempt to answer the query without relevant context, and do not use "
+ "prior knowledge or training data!\n"
+ "If the context does not contain the answer or is empty, only reply that you "
+ "cannot answer the query given the context."
)
diff --git a/rag/retriever/rerank.py b/rag/retriever/rerank.py
deleted file mode 100644
index 08a9a27..0000000
--- a/rag/retriever/rerank.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import os
-from abc import abstractmethod
-from typing import Type
-
-import cohere
-from loguru import logger as log
-from sentence_transformers import CrossEncoder
-
-from rag.generator.prompt import Prompt
-
-
-class AbstractReranker(type):
- _instances = {}
-
- def __call__(cls, *args, **kwargs):
- if cls not in cls._instances:
- instance = super().__call__(*args, **kwargs)
- cls._instances[cls] = instance
- return cls._instances[cls]
-
- @abstractmethod
- def rank(self, prompt: Prompt) -> Prompt:
- return prompt
-
-
-class Reranker(metaclass=AbstractReranker):
- def __init__(self) -> None:
- self.model = CrossEncoder(os.environ["RERANK_MODEL"])
- self.top_k = int(os.environ["RERANK_TOP_K"])
-
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- results = self.model.rank(
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
- )
- prompt.documents = [
- prompt.documents[r.get("corpus_id", 0)] for r in ranking
- ]
- return prompt
-
-
-class CohereReranker(metaclass=AbstractReranker):
- def __init__(self) -> None:
- self.client = cohere.Client(os.environ["COHERE_API_KEY"])
- self.top_k = int(os.environ["RERANK_TOP_K"])
-
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=self.top_k,
- )
- ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
- )
- prompt.documents = [prompt.documents[r.index] for r in ranking]
- return prompt
-
-
-def get_reranker(model: str) -> Type[AbstractReranker]:
- match model:
- case "local":
- return Reranker()
- case "cohere":
- return CohereReranker()
- case _:
- exit(1)
diff --git a/rag/retriever/rerank/__init__.py b/rag/retriever/rerank/__init__.py
new file mode 100644
index 0000000..16b2fac
--- /dev/null
+++ b/rag/retriever/rerank/__init__.py
@@ -0,0 +1,15 @@
+from typing import Type
+
+from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.rerank.cohere import CohereReranker
+from rag.retriever.rerank.local import Reranker
+
+
+def get_reranker(model: str) -> Type[AbstractReranker]:
+ match model:
+ case "local":
+ return Reranker()
+ case "cohere":
+ return CohereReranker()
+ case _:
+ exit(1)
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
new file mode 100644
index 0000000..b96b70a
--- /dev/null
+++ b/rag/retriever/rerank/abstract.py
@@ -0,0 +1,17 @@
+from abc import abstractmethod
+
+from rag.generator.prompt import Prompt
+
+
+class AbstractReranker(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ instance = super().__call__(*args, **kwargs)
+ cls._instances[cls] = instance
+ return cls._instances[cls]
+
+ @abstractmethod
+ def rank(self, prompt: Prompt) -> Prompt:
+ return prompt
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
new file mode 100644
index 0000000..dac9ab5
--- /dev/null
+++ b/rag/retriever/rerank/cohere.py
@@ -0,0 +1,28 @@
+import os
+
+import cohere
+from loguru import logger as log
+
+from rag.generator.prompt import Prompt
+from rag.retriever.rerank.abstract import AbstractReranker
+
+
+class CohereReranker(metaclass=AbstractReranker):
+ def __init__(self) -> None:
+ self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+ self.top_k = int(os.environ["RERANK_TOP_K"])
+
+ def rank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ top_n=self.top_k,
+ )
+ ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [prompt.documents[r.index] for r in ranking]
+ return prompt
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
new file mode 100644
index 0000000..758c5dc
--- /dev/null
+++ b/rag/retriever/rerank/local.py
@@ -0,0 +1,30 @@
+import os
+
+from loguru import logger as log
+from sentence_transformers import CrossEncoder
+
+from rag.generator.prompt import Prompt
+from rag.retriever.rerank.abstract import AbstractReranker
+
+
+class Reranker(metaclass=AbstractReranker):
+ def __init__(self) -> None:
+ self.model = CrossEncoder(os.environ["RERANK_MODEL"])
+ self.top_k = int(os.environ["RERANK_TOP_K"])
+
+ def rank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ results = self.model.rank(
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [
+ prompt.documents[r.get("corpus_id", 0)] for r in ranking
+ ]
+ return prompt
diff --git a/rag/ui.py b/rag/ui.py
index f46c24d..36e8c4c 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -42,7 +42,6 @@ def load_retriever():
def load_generator(model: str):
log.debug("Loading generator model")
st.session_state.generator = get_generator(model)
- set_chat_users()
# @st.cache_resource
@@ -113,7 +112,7 @@ def store_chat(query: str, response: str, documents: List[Document]):
def sidebar():
with st.sidebar:
- st.header("Grouding")
+ st.header("Grounding")
st.markdown(
(
"These files will be uploaded to the knowledge base and used "
@@ -129,7 +128,7 @@ def sidebar():
upload(files)
- st.header("Generative Model")
+ st.header("Model")
st.markdown(
"Select the model that will be used for reranking and generating the answer."
)