diff options
Diffstat (limited to 'rag/cli.py')
-rw-r--r-- | rag/cli.py | 32 |
1 files changed, 13 insertions, 19 deletions
@@ -8,6 +8,7 @@ from tqdm import tqdm from rag.generator import get_generator from rag.generator.prompt import Prompt +from rag.retriever.rerank import get_reranker from rag.retriever.retriever import Retriever @@ -33,11 +34,12 @@ def upload(directory: str): retriever.add_pdf(path=path) -def rag(generator: str, query: str, limit): +def rag(model: str, query: str): retriever = Retriever() - generator = get_generator(generator) - documents = retriever.retrieve(query, limit=limit) - prompt = generator.rerank(Prompt(query, documents)) + generator = get_generator(model) + reranker = get_reranker(model) + documents = retriever.retrieve(query) + prompt = reranker.rerank(Prompt(query, documents)) print("Answer: ") for chunk in generator.generate(prompt): print(chunk, end="", flush=True) @@ -50,6 +52,7 @@ def rag(generator: str, query: str, limit): print("---") +@click.command() @click.option( "-q", "--query", @@ -58,20 +61,12 @@ def rag(generator: str, query: str, limit): prompt="Enter your query", ) @click.option( - "-g", - "--generator", - type=click.Choice(["ollama", "cohere"], case_sensitive=False), - default="ollama", - help="Generator client", -) -@click.option( - "-l", - "--limit", - type=click.IntRange(1, 20, clamp=True), - default=5, - help="Max number of documents used in grouding", + "-m", + "--model", + type=click.Choice(["local", "cohere"], case_sensitive=False), + default="local", + help="Generator and rerank model", ) -@click.command() @click.option( "-d", "--directory", @@ -90,7 +85,6 @@ def rag(generator: str, query: str, limit): def main( query: Optional[str], generator: str, - limit: int, directory: Optional[str], verbose: int, ): @@ -98,7 +92,7 @@ def main( if directory: upload(directory) if query: - rag(generator, query, limit) + rag(generator, query) # TODO: maybe add override for models |