diff options
Diffstat (limited to 'rag/cli.py')
-rw-r--r-- | rag/cli.py | 43 |
1 files changed, 21 insertions, 22 deletions
@@ -5,10 +5,8 @@ from dotenv import load_dotenv from loguru import logger as log from tqdm import tqdm -from rag.generator import get_generator from rag.generator.prompt import Prompt from rag.model import Rag -from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever @@ -51,13 +49,6 @@ def upload(directory: str, verbose: int): @click.command() @click.option( - "-q", - "--query", - prompt_required=False, - help="The query for rag", - prompt="Enter your query", -) -@click.option( "-c", "--client", type=click.Choice(["local", "cohere"], case_sensitive=False), @@ -65,21 +56,29 @@ def upload(directory: str, verbose: int): help="Generator and rerank model", ) @click.option("-v", "--verbose", count=True) -def rag(query: str, client: str, verbose: int): +def rag(client: str, verbose: int): configure_logging(verbose) rag = Rag(client) - documents = rag.retrieve(query) - prompt = Prompt(query, documents, client) - print("Answer: ") - for chunk in rag.generate(prompt): - print(chunk, end="", flush=True) - - print("\n\n") - for i, doc in enumerate(prompt.documents): - print(f"### Document {i}") - print(f"**Title: {doc.title}**") - print(doc.text) - print("---") + while True: + query = input("Query: ") + documents = rag.retrieve(query) + prompt = Prompt(query, documents, client) + print("Answer: ") + response = "" + for chunk in rag.generate(prompt): + print(chunk, end="", flush=True) + response += chunk + + rag.add_message(rag.bot, response) + + show_context = input("Display context? [y/n] ").lower() == "y" + print("\n\n") + if show_context: + for i, doc in enumerate(prompt.documents): + print(f"### Document {i}") + print(f"**Title: {doc.title}**") + print(doc.text) + print("---") @click.command() |