diff options
Diffstat (limited to 'rag')
-rw-r--r-- | rag/__main__.py | 6 | ||||
-rw-r--r-- | rag/cli.py | 43 | ||||
-rw-r--r-- | rag/generator/prompt.py | 10 | ||||
-rw-r--r-- | rag/retriever/rerank/local.py | 8 |
4 files changed, 37 insertions, 30 deletions
diff --git a/rag/__main__.py b/rag/__main__.py new file mode 100644 index 0000000..be85a2e --- /dev/null +++ b/rag/__main__.py @@ -0,0 +1,6 @@ +from dotenv import load_dotenv + +from rag.cli import cli + +load_dotenv() +cli() @@ -5,10 +5,8 @@ from dotenv import load_dotenv from loguru import logger as log from tqdm import tqdm -from rag.generator import get_generator from rag.generator.prompt import Prompt from rag.model import Rag -from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever @@ -51,13 +49,6 @@ def upload(directory: str, verbose: int): @click.command() @click.option( - "-q", - "--query", - prompt_required=False, - help="The query for rag", - prompt="Enter your query", -) -@click.option( "-c", "--client", type=click.Choice(["local", "cohere"], case_sensitive=False), @@ -65,21 +56,29 @@ def upload(directory: str, verbose: int): help="Generator and rerank model", ) @click.option("-v", "--verbose", count=True) -def rag(query: str, client: str, verbose: int): +def rag(client: str, verbose: int): configure_logging(verbose) rag = Rag(client) - documents = rag.retrieve(query) - prompt = Prompt(query, documents, client) - print("Answer: ") - for chunk in rag.generate(prompt): - print(chunk, end="", flush=True) - - print("\n\n") - for i, doc in enumerate(prompt.documents): - print(f"### Document {i}") - print(f"**Title: {doc.title}**") - print(doc.text) - print("---") + while True: + query = input("Query: ") + documents = rag.retrieve(query) + prompt = Prompt(query, documents, client) + print("Answer: ") + response = "" + for chunk in rag.generate(prompt): + print(chunk, end="", flush=True) + response += chunk + + rag.add_message(rag.bot, response) + + show_context = input("Display context? [y/n] ").lower() == "y" + print("\n\n") + if show_context: + for i, doc in enumerate(prompt.documents): + print(f"### Document {i}") + print(f"**Title: {doc.title}**") + print(doc.text) + print("---") @click.command() diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py index cedf610..75966e8 100644 --- a/rag/generator/prompt.py +++ b/rag/generator/prompt.py @@ -29,11 +29,13 @@ class Prompt: return f"{self.query}\n\n{ANSWER_INSTRUCTION}" else: return ( - "Context information is below.\n" + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n" + "Using the information contained in the context, give a comprehensive answer to the question.\n" + "If the answer cannot be deduced from the context, do not give an answer.\n\n" + "Context:\n" "---\n" f"{self.__context(self.documents)}\n\n" "---\n" - f"{ANSWER_INSTRUCTION}" - f"Query: {self.query.strip()}\n\n" - "Answer:" + f"Question: {self.query}<|eot_id|>\n" + "<|start_header_id|>assistant<|end_header_id|>\n\n" ) diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py index fd42c2c..231d50a 100644 --- a/rag/retriever/rerank/local.py +++ b/rag/retriever/rerank/local.py @@ -2,16 +2,16 @@ import os from typing import List from loguru import logger as log -from rag.message import Message -from rag.retriever.vector import Document from sentence_transformers import CrossEncoder +from rag.message import Message from rag.retriever.rerank.abstract import AbstractReranker +from rag.retriever.vector import Document class Reranker(metaclass=AbstractReranker): def __init__(self) -> None: - self.model = CrossEncoder(os.environ["RERANK_MODEL"]) + self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu") self.top_k = int(os.environ["RERANK_TOP_K"]) self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"]) @@ -33,7 +33,7 @@ class Reranker(metaclass=AbstractReranker): def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]: results = self.model.rank( query=query, - documents=[m.message for m in messages], + documents=[m.content for m in messages], return_documents=False, top_k=self.top_k, ) |