summaryrefslogtreecommitdiff
path: root/rag/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/cli.py')
-rw-r--r--rag/cli.py43
1 files changed, 21 insertions, 22 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 070427d..b047255 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -5,10 +5,8 @@ from dotenv import load_dotenv
from loguru import logger as log
from tqdm import tqdm
-from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.model import Rag
-from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
@@ -51,13 +49,6 @@ def upload(directory: str, verbose: int):
@click.command()
@click.option(
- "-q",
- "--query",
- prompt_required=False,
- help="The query for rag",
- prompt="Enter your query",
-)
-@click.option(
"-c",
"--client",
type=click.Choice(["local", "cohere"], case_sensitive=False),
@@ -65,21 +56,29 @@ def upload(directory: str, verbose: int):
help="Generator and rerank model",
)
@click.option("-v", "--verbose", count=True)
-def rag(query: str, client: str, verbose: int):
+def rag(client: str, verbose: int):
configure_logging(verbose)
rag = Rag(client)
- documents = rag.retrieve(query)
- prompt = Prompt(query, documents, client)
- print("Answer: ")
- for chunk in rag.generate(prompt):
- print(chunk, end="", flush=True)
-
- print("\n\n")
- for i, doc in enumerate(prompt.documents):
- print(f"### Document {i}")
- print(f"**Title: {doc.title}**")
- print(doc.text)
- print("---")
+ while True:
+ query = input("Query: ")
+ documents = rag.retrieve(query)
+ prompt = Prompt(query, documents, client)
+ print("Answer: ")
+ response = ""
+ for chunk in rag.generate(prompt):
+ print(chunk, end="", flush=True)
+ response += chunk
+
+ rag.add_message(rag.bot, response)
+
+ show_context = input("Display context? [y/n] ").lower() == "y"
+ print("\n\n")
+ if show_context:
+ for i, doc in enumerate(prompt.documents):
+ print(f"### Document {i}")
+ print(f"**Title: {doc.title}**")
+ print(doc.text)
+ print("---")
@click.command()