summaryrefslogtreecommitdiff
path: root/rag/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/cli.py')
-rw-r--r--rag/cli.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 6c4d3e0..070427d 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
from rag.generator import get_generator
from rag.generator.prompt import Prompt
+from rag.model import Rag
from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
@@ -57,22 +58,20 @@ def upload(directory: str, verbose: int):
prompt="Enter your query",
)
@click.option(
- "-m",
- "--model",
+ "-c",
+ "--client",
type=click.Choice(["local", "cohere"], case_sensitive=False),
default="local",
help="Generator and rerank model",
)
@click.option("-v", "--verbose", count=True)
-def rag(query: str, model: str, verbose: int):
+def rag(query: str, client: str, verbose: int):
configure_logging(verbose)
- retriever = Retriever()
- generator = get_generator(model)
- reranker = get_reranker(model)
- documents = retriever.retrieve(query)
- prompt = reranker.rank(Prompt(query, documents))
+ rag = Rag(client)
+ documents = rag.retrieve(query)
+ prompt = Prompt(query, documents, client)
print("Answer: ")
- for chunk in generator.generate(prompt):
+ for chunk in rag.generate(prompt):
print(chunk, end="", flush=True)
print("\n\n")