diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-06-19 02:07:06 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-06-19 02:07:06 +0200 |
commit | aac821b148c6c0d35b940609dc9b0ddcb053b28e (patch) | |
tree | 5c125045b2b60ead39e093327d664adf43d1d35b /rag/cli.py | |
parent | f2846429310452bebbf0d07203b1e53978c439c7 (diff) |
Still wip on rewrite
Diffstat (limited to 'rag/cli.py')
-rw-r--r-- | rag/cli.py | 17 |
1 files changed, 8 insertions, 9 deletions
@@ -7,6 +7,7 @@ from tqdm import tqdm from rag.generator import get_generator from rag.generator.prompt import Prompt +from rag.model import Rag from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever @@ -57,22 +58,20 @@ def upload(directory: str, verbose: int): prompt="Enter your query", ) @click.option( - "-m", - "--model", + "-c", + "--client", type=click.Choice(["local", "cohere"], case_sensitive=False), default="local", help="Generator and rerank model", ) @click.option("-v", "--verbose", count=True) -def rag(query: str, model: str, verbose: int): +def rag(query: str, client: str, verbose: int): configure_logging(verbose) - retriever = Retriever() - generator = get_generator(model) - reranker = get_reranker(model) - documents = retriever.retrieve(query) - prompt = reranker.rank(Prompt(query, documents)) + rag = Rag(client) + documents = rag.retrieve(query) + prompt = Prompt(query, documents, client) print("Answer: ") - for chunk in generator.generate(prompt): + for chunk in rag.generate(prompt): print(chunk, end="", flush=True) print("\n\n") |