From 95305f59df84caded50286b1a57b6075e48725a8 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Wed, 24 Apr 2024 01:10:43 +0200
Subject: Rerank working

llama3 sucks at rag
---
 rag/cli.py                 | 32 ++++++++-----------
 rag/generator/__init__.py  |  4 +--
 rag/generator/abstract.py  |  4 ---
 rag/generator/ollama.py    | 10 +++---
 rag/generator/prompt.py    |  4 +--
 rag/retriever/rerank.py    | 77 ++++++++++++++++++++++++++++++++++++++++++++++
 rag/retriever/retriever.py |  4 +--
 rag/retriever/vector.py    | 11 ++++---
 rag/ui.py                  | 51 +++++++++++++++---------------
 9 files changed, 131 insertions(+), 66 deletions(-)
 create mode 100644 rag/retriever/rerank.py
(limited to 'rag')

diff --git a/rag/cli.py b/rag/cli.py
index 932e2a9..b210808 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
 
 from rag.generator import get_generator
 from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
 from rag.retriever.retriever import Retriever
 
 
@@ -33,11 +34,12 @@ def upload(directory: str):
         retriever.add_pdf(path=path)
 
 
-def rag(generator: str, query: str, limit):
+def rag(model: str, query: str):
     retriever = Retriever()
-    generator = get_generator(generator)
-    documents = retriever.retrieve(query, limit=limit)
-    prompt = generator.rerank(Prompt(query, documents))
+    generator = get_generator(model)
+    reranker = get_reranker(model)
+    documents = retriever.retrieve(query)
+    prompt = reranker.rank(Prompt(query, documents))
     print("Answer: ")
     for chunk in generator.generate(prompt):
         print(chunk, end="", flush=True)
@@ -50,6 +52,7 @@ def rag(generator: str, query: str, limit):
         print("---")
 
 
+@click.command()
 @click.option(
     "-q",
     "--query",
@@ -58,20 +61,12 @@ def rag(generator: str, query: str, limit):
     prompt="Enter your query",
 )
 @click.option(
-    "-g",
-    "--generator",
-    type=click.Choice(["ollama", "cohere"], case_sensitive=False),
-    default="ollama",
-    help="Generator client",
-)
-@click.option(
-    "-l",
-    "--limit",
-    type=click.IntRange(1, 20, clamp=True),
-    default=5,
-    help="Max number of documents used in grouding",
+    "-m",
+    "--model",
+    type=click.Choice(["local", "cohere"], case_sensitive=False),
+    default="local",
+    help="Generator and rerank model",
 )
-@click.command()
 @click.option(
     "-d",
     "--directory",
@@ -90,7 +85,6 @@
 def main(
     query: Optional[str],
-    generator: str,
-    limit: int,
+    model: str,
     directory: Optional[str],
     verbose: int,
 ):
@@ -98,7 +92,7 @@ def main(
     if directory:
         upload(directory)
     if query:
-        rag(generator, query, limit)
+        rag(model, query)
 
 
 # TODO: maybe add override for models
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
index ba23ffc..a776231 100644
--- a/rag/generator/__init__.py
+++ b/rag/generator/__init__.py
@@ -4,11 +4,11 @@ from .abstract import AbstractGenerator
 from .cohere import Cohere
 from .ollama import Ollama
 
-MODELS = ["ollama", "cohere"]
+MODELS = ["local", "cohere"]
 
 
 def get_generator(model: str) -> Type[AbstractGenerator]:
     match model:
-        case "ollama":
+        case "local":
             return Ollama()
         case "cohere":
             return Cohere()
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 439c1b5..1beacfb 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -16,7 +16,3 @@ class AbstractGenerator(type):
     @abstractmethod
     def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
         pass
-
-    @abstractmethod
-    def rerank(self, prompt: Prompt) -> Prompt:
-        return prompt
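With rerank() removed from AbstractGenerator, reranking now lives in its own rag.retriever.rerank module, and the CLI composes the pipeline as retrieve -> rerank -> generate. A minimal sketch of that flow, assuming the package layout from this patch (the query string is illustrative):

from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever

retriever = Retriever()
generator = get_generator("local")  # "local" selects Ollama, "cohere" the Cohere client
reranker = get_reranker("local")    # the same flag picks the matching reranker

query = "What is retrieval augmented generation?"  # illustrative query
documents = retriever.retrieve(query)              # top-k hits from the vector db
prompt = reranker.rank(Prompt(query, documents))   # drop low-scoring documents
for chunk in generator.generate(prompt):           # stream the grounded answer
    print(chunk, end="", flush=True)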
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index b72d763..9118906 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, Generator, List
+from typing import Any, Generator, List
 
 import ollama
 from loguru import logger as log
@@ -24,12 +24,12 @@ class Ollama(metaclass=AbstractGenerator):
 
     def __metaprompt(self, prompt: Prompt) -> str:
         metaprompt = (
-            "Answer the question based only on the following context:\n"
-            "<context>\n"
-            f"{self.__context(prompt.documents)}\n\n"
-            "</context>\n"
             f"{ANSWER_INSTRUCTION}"
+            "Only the information between <context> and </context> should be used to answer the question.\n"
             f"Question: {prompt.query.strip()}\n\n"
+            "<context>\n"
+            f"{self.__context(prompt.documents)}\n\n"
+            "</context>\n"
             "Answer:"
         )
         return metaprompt
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index fa007db..f607122 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -5,8 +5,8 @@ from rag.retriever.vector import Document
 
 ANSWER_INSTRUCTION = (
     "Given the context information and not prior knowledge, answer the question."
-    "If the context is irrelevant to the question, answer that you do not know "
-    "the answer to the question given the context and stop.\n"
+    "If the context is irrelevant to the question or empty, then do not attempt to answer "
+    "the question, just reply that you do not know based on the context provided.\n"
 )
 
 
diff --git a/rag/retriever/rerank.py b/rag/retriever/rerank.py
new file mode 100644
index 0000000..08a9a27
--- /dev/null
+++ b/rag/retriever/rerank.py
@@ -0,0 +1,77 @@
+import os
+from abc import abstractmethod
+from typing import Type
+
+import cohere
+from loguru import logger as log
+from sentence_transformers import CrossEncoder
+
+from rag.generator.prompt import Prompt
+
+
+class AbstractReranker(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
+
+    @abstractmethod
+    def rank(self, prompt: Prompt) -> Prompt:
+        return prompt
+
+
+class Reranker(metaclass=AbstractReranker):
+    def __init__(self) -> None:
+        self.model = CrossEncoder(os.environ["RERANK_MODEL"])
+        self.top_k = int(os.environ["RERANK_TOP_K"])
+
+    def rank(self, prompt: Prompt) -> Prompt:
+        if prompt.documents:
+            results = self.model.rank(
+                query=prompt.query,
+                documents=[d.text for d in prompt.documents],
+                return_documents=False,
+                top_k=self.top_k,
+            )
+            ranking = list(filter(lambda x: x.get("score", 0.0) > 0.5, results))
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+            )
+            prompt.documents = [
+                prompt.documents[r.get("corpus_id", 0)] for r in ranking
+            ]
+        return prompt
+
+
+class CohereReranker(metaclass=AbstractReranker):
+    def __init__(self) -> None:
+        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
+        self.top_k = int(os.environ["RERANK_TOP_K"])
+
+    def rank(self, prompt: Prompt) -> Prompt:
+        if prompt.documents:
+            response = self.client.rerank(
+                model="rerank-english-v3.0",
+                query=prompt.query,
+                documents=[d.text for d in prompt.documents],
+                top_n=self.top_k,
+            )
+            ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results))
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+            )
+            prompt.documents = [prompt.documents[r.index] for r in ranking]
+        return prompt
+
+
+def get_reranker(model: str) -> Type[AbstractReranker]:
+    match model:
+        case "local":
+            return Reranker()
+        case "cohere":
+            return CohereReranker()
+        case _:
+            exit(1)
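The local reranker leans on CrossEncoder.rank from sentence-transformers, which scores each (query, document) pair jointly and returns dicts with "corpus_id" and "score", ordered best first; for single-label cross-encoders the library applies a sigmoid by default, so scores land in [0, 1] and the 0.5 cutoff above is meaningful. A standalone sketch of that behavior, where the model id is an illustrative choice rather than one taken from this repo (the patch reads the real one from RERANK_MODEL):

from sentence_transformers import CrossEncoder

# Illustrative cross-encoder; any sentence-transformers rerank model works.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "what does a cross-encoder do?"
docs = [
    "A cross-encoder scores a query and a document jointly.",
    "Bananas are rich in potassium.",
]
# rank() returns [{"corpus_id": int, "score": float}, ...], best first.
results = model.rank(query=query, documents=docs, return_documents=False, top_k=2)
kept = [docs[r["corpus_id"]] for r in results if r["score"] > 0.5]
print(kept)  # the off-topic document should fall below the cutoff

Note also that AbstractReranker doubles as a singleton metaclass, so repeated get_reranker() calls reuse the already-loaded CrossEncoder or Cohere client instead of constructing a new one.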
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index deffae5..351cfb0 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -45,7 +45,7 @@ class Retriever:
         else:
             log.error("Invalid input!")
 
-    def retrieve(self, query: str, limit: int = 5) -> List[Document]:
+    def retrieve(self, query: str) -> List[Document]:
         log.debug(f"Finding documents matching query: {query}")
         query_emb = self.encoder.encode_query(query)
-        return self.vec_db.search(query_emb, limit)
+        return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index b72a3c1..1a484f3 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -22,11 +22,12 @@ class Document:
 
 
 class VectorDB:
-    def __init__(self, score_threshold: float = 0.5):
+    def __init__(self):
         self.dim = int(os.environ["EMBEDDING_DIM"])
         self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
         self.client = QdrantClient(url=os.environ["QDRANT_URL"])
-        self.score_threshold = score_threshold
+        self.top_k = int(os.environ["RETRIEVER_TOP_K"])
+        self.score_threshold = float(os.environ["RETRIEVER_SCORE_THRESHOLD"])
         self.__configure()
 
     def __configure(self):
@@ -58,15 +59,15 @@ class VectorDB:
             max_retries=3,
         )
 
-    def search(self, query: List[float], limit: int = 5) -> List[Document]:
+    def search(self, query: List[float]) -> List[Document]:
         log.debug("Searching for vectors...")
         hits = self.client.search(
             collection_name=self.collection_name,
             query_vector=query,
-            limit=limit,
+            limit=self.top_k,
             score_threshold=self.score_threshold,
         )
-        log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}")
+        log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
         return list(
             map(
                 lambda h: Document(
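The limit and score-threshold knobs that used to be function arguments are now read entirely from the environment. A sketch of the variables the retriever and reranker expect after this patch; the names come from the code above, while the values are illustrative guesses, not taken from the repo:

import os

os.environ.setdefault("EMBEDDING_DIM", "768")              # dimension of the embedding model
os.environ.setdefault("QDRANT_COLLECTION_NAME", "knowledge-base")
os.environ.setdefault("QDRANT_URL", "http://localhost:6333")
os.environ.setdefault("RETRIEVER_TOP_K", "15")             # replaces retrieve(..., limit=...)
os.environ.setdefault("RETRIEVER_SCORE_THRESHOLD", "0.5")  # replaces the old constructor default
os.environ.setdefault("RERANK_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
os.environ.setdefault("RERANK_TOP_K", "5")                 # max documents kept after reranking
os.environ.setdefault("COHERE_API_KEY", "...")             # only needed for the cohere model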
diff --git a/rag/ui.py b/rag/ui.py
index 2fbf8de..f46c24d 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from enum import Enum
 from typing import Dict, List
 
 import streamlit as st
@@ -9,27 +8,18 @@ from loguru import logger as log
 
 from rag.generator import MODELS, get_generator
 from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
 from rag.retriever.retriever import Retriever
 from rag.retriever.vector import Document
 
 
-class Cohere(Enum):
-    USER = "USER"
-    BOT = "CHATBOT"
-
-
-class Ollama(Enum):
-    USER = "user"
-    BOT = "assistant"
-
-
 @dataclass
 class Message:
     role: str
     message: str
 
-    def as_dict(self, client: str) -> Dict[str, str]:
-        if client == "cohere":
+    def as_dict(self, model: str) -> Dict[str, str]:
+        if model == "cohere":
             return {"role": self.role, "message": self.message}
         else:
             return {"role": self.role, "content": self.message}
@@ -38,12 +28,8 @@ class Message:
 def set_chat_users():
     log.debug("Setting user and bot value")
     ss = st.session_state
-    if ss.generator == "cohere":
-        ss.user = Cohere.USER.value
-        ss.bot = Cohere.BOT.value
-    else:
-        ss.user = Ollama.USER.value
-        ss.bot = Ollama.BOT.value
+    ss.user = "user"
+    ss.bot = "assistant"
 
 
 @st.cache_resource
@@ -52,13 +38,19 @@ def load_retriever():
     st.session_state.retriever = Retriever()
 
 
-@st.cache_resource
-def load_generator(client: str):
+# @st.cache_resource
+def load_generator(model: str):
     log.debug("Loading generator model")
-    st.session_state.generator = get_generator(client)
+    st.session_state.generator = get_generator(model)
     set_chat_users()
 
 
+# @st.cache_resource
+def load_reranker(model: str):
+    log.debug("Loading reranker model")
+    st.session_state.reranker = get_reranker(model)
+
+
 @st.cache_data(show_spinner=False)
 def upload(files):
     retriever = st.session_state.retriever
@@ -95,11 +87,12 @@ def generate_chat(query: str):
 
     retriever = ss.retriever
     generator = ss.generator
+    reranker = ss.reranker
 
-    documents = retriever.retrieve(query, limit=15)
+    documents = retriever.retrieve(query)
     prompt = Prompt(query, documents)
 
-    prompt = generator.rerank(prompt)
+    prompt = reranker.rank(prompt)
 
     with st.chat_message(ss.bot):
         response = st.write_stream(generator.generate(prompt))
@@ -137,9 +130,12 @@ def sidebar():
             upload(files)
 
         st.header("Generative Model")
-        st.markdown("Select the model that will be used for generating the answer.")
-        st.selectbox("Generative Model", key="client", options=MODELS)
-        load_generator(st.session_state.client)
+        st.markdown(
+            "Select the model that will be used for reranking and generating the answer."
+        )
+        st.selectbox("Model", key="model", options=MODELS)
+        load_generator(st.session_state.model)
+        load_reranker(st.session_state.model)
 
 
 def page():
@@ -157,6 +153,7 @@ def page():
 if __name__ == "__main__":
     load_dotenv()
     st.title("Retrieval Augmented Generation")
+    set_chat_users()
     load_retriever()
     sidebar()
     page()
--
cgit v1.2.3-70-g09d2
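One detail worth noting in the UI change: although the per-client role Enums are gone, Message.as_dict still keeps the Cohere chat schema ({"role", "message"}) separate from the Ollama-style schema ({"role", "content"}). A quick check of both shapes, assuming the module layout from this patch:

from rag.ui import Message

msg = Message(role="user", message="hello")
assert msg.as_dict("cohere") == {"role": "user", "message": "hello"}
assert msg.as_dict("local") == {"role": "user", "content": "hello"}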