From fd635eb8e61c18d7dbe4d430e2e3eeb4a1755947 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Apr 2024 23:36:58 +0200 Subject: Refactor cli --- rag/cli.py | 93 +++++++++++++++++++++++++++++++------------------------------ rag/drop.py | 25 ----------------- 2 files changed, 48 insertions(+), 70 deletions(-) delete mode 100644 rag/drop.py (limited to 'rag') diff --git a/rag/cli.py b/rag/cli.py index b210808..e8fdd86 100644 --- a/rag/cli.py +++ b/rag/cli.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional import click from dotenv import load_dotenv @@ -26,7 +25,22 @@ def configure_logging(verbose: int): log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) -def upload(directory: str): +@click.group() +def cli(): + pass + + +@click.command() +@click.option( + "-d", + "--directory", + help="The full path to the root directory containing pdfs to upload", + type=click.Path(exists=True), + default=None, +) +@click.option("-v", "--verbose", count=True) +def upload(directory: str, verbose: int): + configure_logging(verbose) log.info(f"Uploading pfs found in directory {directory}...") retriever = Retriever() pdfs = Path(directory).glob("**/*.pdf") @@ -34,7 +48,24 @@ def upload(directory: str): retriever.add_pdf(path=path) -def rag(model: str, query: str): +@click.command() +@click.option( + "-q", + "--query", + prompt_required=False, + help="The query for rag", + prompt="Enter your query", +) +@click.option( + "-m", + "--model", + type=click.Choice(["local", "cohere"], case_sensitive=False), + default="local", + help="Generator and rerank model", +) +@click.option("-v", "--verbose", count=True) +def rag(query: str, model: str, verbose: int): + configure_logging(verbose) retriever = Retriever() generator = get_generator(model) reranker = get_reranker(model) @@ -53,49 +84,21 @@ def rag(model: str, query: str): @click.command() -@click.option( - "-q", - "--query", - prompt_required=False, - help="The query for rag", - prompt="Enter your query", -) -@click.option( - "-m", - "--model", - type=click.Choice(["local", "cohere"], case_sensitive=False), - default="local", - help="Generator and rerank model", -) -@click.option( - "-d", - "--directory", - help="The full path to the root directory containing pdfs to upload", - type=click.Path(exists=True), - default=None, -) -@click.option( - "-q", - "--query", - prompt_required=False, - help="The query for rag", - prompt="Enter your query", -) -@click.option("-v", "--verbose", count=True) -def main( - query: Optional[str], - generator: str, - directory: Optional[str], - verbose: int, -): - configure_logging(verbose) - if directory: - upload(directory) - if query: - rag(generator, query) - # TODO: maybe add override for models +@click.confirmation_option(prompt="Are you sure you want to drop the db?") +def drop(): + log.debug("Deleting all data...") + retriever = Retriever() + doc_db = retriever.doc_db + doc_db.delete_all() + vec_db = retriever.vec_db + vec_db.delete_collection() + +cli.add_command(configure_logging) +cli.add_command(rag) +cli.add_command(upload) +cli.add_command(drop) if __name__ == "__main__": load_dotenv() - main() + cli() diff --git a/rag/drop.py b/rag/drop.py deleted file mode 100644 index 5ff5983..0000000 --- a/rag/drop.py +++ /dev/null @@ -1,25 +0,0 @@ -import click -from dotenv import load_dotenv -from loguru import logger as log - -from rag.retriever.retriever import Retriever - - -def drop(): - log.debug("Deleting all data...") - retriever = Retriever() - doc_db = retriever.doc_db - doc_db.delete_all() - vec_db = retriever.vec_db - vec_db.delete_collection() - - -@click.command() -@click.confirmation_option(prompt="Are you sure you want to drop the db?") -def main(): - drop() - - -if __name__ == "__main__": - load_dotenv() - main() -- cgit v1.2.3-70-g09d2