summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 01:10:43 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 01:10:43 +0200
commit95305f59df84caded50286b1a57b6075e48725a8 (patch)
treec0f0157a99da6332a3c96462b0aba2bd02dfcb33 /rag
parent75be0914f6bd2cdeda1539f83b38fcbc854d5cfa (diff)
Rerank working
llama3 sucks at rag
Diffstat (limited to 'rag')
-rw-r--r--rag/cli.py32
-rw-r--r--rag/generator/__init__.py4
-rw-r--r--rag/generator/abstract.py4
-rw-r--r--rag/generator/ollama.py10
-rw-r--r--rag/generator/prompt.py4
-rw-r--r--rag/retriever/rerank.py77
-rw-r--r--rag/retriever/retriever.py4
-rw-r--r--rag/retriever/vector.py11
-rw-r--r--rag/ui.py51
9 files changed, 131 insertions, 66 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 932e2a9..b210808 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
from rag.generator import get_generator
from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
@@ -33,11 +34,12 @@ def upload(directory: str):
retriever.add_pdf(path=path)
-def rag(generator: str, query: str, limit):
+def rag(model: str, query: str):
retriever = Retriever()
- generator = get_generator(generator)
- documents = retriever.retrieve(query, limit=limit)
- prompt = generator.rerank(Prompt(query, documents))
+ generator = get_generator(model)
+ reranker = get_reranker(model)
+ documents = retriever.retrieve(query)
+ prompt = reranker.rerank(Prompt(query, documents))
print("Answer: ")
for chunk in generator.generate(prompt):
print(chunk, end="", flush=True)
@@ -50,6 +52,7 @@ def rag(generator: str, query: str, limit):
print("---")
+@click.command()
@click.option(
"-q",
"--query",
@@ -58,20 +61,12 @@ def rag(generator: str, query: str, limit):
prompt="Enter your query",
)
@click.option(
- "-g",
- "--generator",
- type=click.Choice(["ollama", "cohere"], case_sensitive=False),
- default="ollama",
- help="Generator client",
-)
-@click.option(
- "-l",
- "--limit",
- type=click.IntRange(1, 20, clamp=True),
- default=5,
- help="Max number of documents used in grouding",
+ "-m",
+ "--model",
+ type=click.Choice(["local", "cohere"], case_sensitive=False),
+ default="local",
+ help="Generator and rerank model",
)
-@click.command()
@click.option(
"-d",
"--directory",
@@ -90,7 +85,6 @@ def rag(generator: str, query: str, limit):
def main(
query: Optional[str],
generator: str,
- limit: int,
directory: Optional[str],
verbose: int,
):
@@ -98,7 +92,7 @@ def main(
if directory:
upload(directory)
if query:
- rag(generator, query, limit)
+ rag(generator, query)
# TODO: maybe add override for models
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
index ba23ffc..a776231 100644
--- a/rag/generator/__init__.py
+++ b/rag/generator/__init__.py
@@ -4,11 +4,11 @@ from .abstract import AbstractGenerator
from .cohere import Cohere
from .ollama import Ollama
-MODELS = ["ollama", "cohere"]
+MODELS = ["local", "cohere"]
def get_generator(model: str) -> Type[AbstractGenerator]:
match model:
- case "ollama":
+ case "local":
return Ollama()
case "cohere":
return Cohere()
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 439c1b5..1beacfb 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -16,7 +16,3 @@ class AbstractGenerator(type):
@abstractmethod
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
pass
-
- @abstractmethod
- def rerank(self, prompt: Prompt) -> Prompt:
- return prompt
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index b72d763..9118906 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Dict, Generator, List
+from typing import Any, Generator, List
import ollama
from loguru import logger as log
@@ -24,12 +24,12 @@ class Ollama(metaclass=AbstractGenerator):
def __metaprompt(self, prompt: Prompt) -> str:
metaprompt = (
- "Answer the question based only on the following context:\n"
- "<context>\n"
- f"{self.__context(prompt.documents)}\n\n"
- "</context>\n"
f"{ANSWER_INSTRUCTION}"
+ "Only the information between <results>...</results> should be used to answer the question.\n"
f"Question: {prompt.query.strip()}\n\n"
+ "<results>\n"
+ f"{self.__context(prompt.documents)}\n\n"
+ "</results>\n"
"Answer:"
)
return metaprompt
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index fa007db..f607122 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -5,8 +5,8 @@ from rag.retriever.vector import Document
ANSWER_INSTRUCTION = (
"Given the context information and not prior knowledge, answer the question."
- "If the context is irrelevant to the question, answer that you do not know "
- "the answer to the question given the context and stop.\n"
+ "If the context is irrelevant to the question or empty, then do not attempt to answer "
+ "the question, just reply that you do not know based on the context provided.\n"
)
diff --git a/rag/retriever/rerank.py b/rag/retriever/rerank.py
new file mode 100644
index 0000000..08a9a27
--- /dev/null
+++ b/rag/retriever/rerank.py
@@ -0,0 +1,77 @@
+import os
+from abc import abstractmethod
+from typing import Type
+
+import cohere
+from loguru import logger as log
+from sentence_transformers import CrossEncoder
+
+from rag.generator.prompt import Prompt
+
+
+class AbstractReranker(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ instance = super().__call__(*args, **kwargs)
+ cls._instances[cls] = instance
+ return cls._instances[cls]
+
+ @abstractmethod
+ def rank(self, prompt: Prompt) -> Prompt:
+ return prompt
+
+
+class Reranker(metaclass=AbstractReranker):
+ def __init__(self) -> None:
+ self.model = CrossEncoder(os.environ["RERANK_MODEL"])
+ self.top_k = int(os.environ["RERANK_TOP_K"])
+
+ def rank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ results = self.model.rank(
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [
+ prompt.documents[r.get("corpus_id", 0)] for r in ranking
+ ]
+ return prompt
+
+
+class CohereReranker(metaclass=AbstractReranker):
+ def __init__(self) -> None:
+ self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+ self.top_k = int(os.environ["RERANK_TOP_K"])
+
+ def rank(self, prompt: Prompt) -> Prompt:
+ if prompt.documents:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=prompt.query,
+ documents=[d.text for d in prompt.documents],
+ top_n=self.top_k,
+ )
+ ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ prompt.documents = [prompt.documents[r.index] for r in ranking]
+ return prompt
+
+
+def get_reranker(model: str) -> Type[AbstractReranker]:
+ match model:
+ case "local":
+ return Reranker()
+ case "cohere":
+ return CohereReranker()
+ case _:
+ exit(1)
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index deffae5..351cfb0 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -45,7 +45,7 @@ class Retriever:
else:
log.error("Invalid input!")
- def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+ def retrieve(self, query: str) -> List[Document]:
log.debug(f"Finding documents matching query: {query}")
query_emb = self.encoder.encode_query(query)
- return self.vec_db.search(query_emb, limit)
+ return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index b72a3c1..1a484f3 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -22,11 +22,12 @@ class Document:
class VectorDB:
- def __init__(self, score_threshold: float = 0.5):
+ def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
- self.score_threshold = score_threshold
+ self.top_k = int(os.environ["RETRIEVER_TOP_K"])
+ self.score_threshold = float(os.environ["RETRIEVER_SCORE_THRESHOLD"])
self.__configure()
def __configure(self):
@@ -58,15 +59,15 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float], limit: int = 5) -> List[Document]:
+ def search(self, query: List[float]) -> List[Document]:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
query_vector=query,
- limit=limit,
+ limit=self.top_k,
score_threshold=self.score_threshold,
)
- log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+ log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
return list(
map(
lambda h: Document(
diff --git a/rag/ui.py b/rag/ui.py
index 2fbf8de..f46c24d 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
-from enum import Enum
from typing import Dict, List
import streamlit as st
@@ -9,27 +8,18 @@ from loguru import logger as log
from rag.generator import MODELS, get_generator
from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
from rag.retriever.vector import Document
-class Cohere(Enum):
- USER = "USER"
- BOT = "CHATBOT"
-
-
-class Ollama(Enum):
- USER = "user"
- BOT = "assistant"
-
-
@dataclass
class Message:
role: str
message: str
- def as_dict(self, client: str) -> Dict[str, str]:
- if client == "cohere":
+ def as_dict(self, model: str) -> Dict[str, str]:
+ if model == "cohere":
return {"role": self.role, "message": self.message}
else:
return {"role": self.role, "content": self.message}
@@ -38,12 +28,8 @@ class Message:
def set_chat_users():
log.debug("Setting user and bot value")
ss = st.session_state
- if ss.generator == "cohere":
- ss.user = Cohere.USER.value
- ss.bot = Cohere.BOT.value
- else:
- ss.user = Ollama.USER.value
- ss.bot = Ollama.BOT.value
+ ss.user = "user"
+ ss.bot = "assistant"
@st.cache_resource
@@ -52,13 +38,19 @@ def load_retriever():
st.session_state.retriever = Retriever()
-@st.cache_resource
-def load_generator(client: str):
+# @st.cache_resource
+def load_generator(model: str):
log.debug("Loading generator model")
- st.session_state.generator = get_generator(client)
+ st.session_state.generator = get_generator(model)
set_chat_users()
+# @st.cache_resource
+def load_reranker(model: str):
+ log.debug("Loading reranker model")
+ st.session_state.reranker = get_reranker(model)
+
+
@st.cache_data(show_spinner=False)
def upload(files):
retriever = st.session_state.retriever
@@ -95,11 +87,12 @@ def generate_chat(query: str):
retriever = ss.retriever
generator = ss.generator
+ reranker = ss.reranker
- documents = retriever.retrieve(query, limit=15)
+ documents = retriever.retrieve(query)
prompt = Prompt(query, documents)
- prompt = generator.rerank(prompt)
+ prompt = reranker.rank(prompt)
with st.chat_message(ss.bot):
response = st.write_stream(generator.generate(prompt))
@@ -137,9 +130,12 @@ def sidebar():
upload(files)
st.header("Generative Model")
- st.markdown("Select the model that will be used for generating the answer.")
- st.selectbox("Generative Model", key="client", options=MODELS)
- load_generator(st.session_state.client)
+ st.markdown(
+ "Select the model that will be used for reranking and generating the answer."
+ )
+ st.selectbox("Model", key="model", options=MODELS)
+ load_generator(st.session_state.model)
+ load_reranker(st.session_state.model)
def page():
@@ -157,6 +153,7 @@ def page():
if __name__ == "__main__":
load_dotenv()
st.title("Retrieval Augmented Generation")
+ set_chat_users()
load_retriever()
sidebar()
page()