From aac821b148c6c0d35b940609dc9b0ddcb053b28e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 19 Jun 2024 02:07:06 +0200 Subject: Still wip on rewrite --- rag/cli.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) (limited to 'rag/cli.py') diff --git a/rag/cli.py b/rag/cli.py index 6c4d3e0..070427d 100644 --- a/rag/cli.py +++ b/rag/cli.py @@ -7,6 +7,7 @@ from tqdm import tqdm from rag.generator import get_generator from rag.generator.prompt import Prompt +from rag.model import Rag from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever @@ -57,22 +58,20 @@ def upload(directory: str, verbose: int): prompt="Enter your query", ) @click.option( - "-m", - "--model", + "-c", + "--client", type=click.Choice(["local", "cohere"], case_sensitive=False), default="local", help="Generator and rerank model", ) @click.option("-v", "--verbose", count=True) -def rag(query: str, model: str, verbose: int): +def rag(query: str, client: str, verbose: int): configure_logging(verbose) - retriever = Retriever() - generator = get_generator(model) - reranker = get_reranker(model) - documents = retriever.retrieve(query) - prompt = reranker.rank(Prompt(query, documents)) + rag = Rag(client) + documents = rag.retrieve(query) + prompt = Prompt(query, documents, client) print("Answer: ") - for chunk in generator.generate(prompt): + for chunk in rag.generate(prompt): print(chunk, end="", flush=True) print("\n\n") -- cgit v1.2.3-70-g09d2