summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 23:36:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-24 23:36:58 +0200
commitfd635eb8e61c18d7dbe4d430e2e3eeb4a1755947 (patch)
tree3860127c5f456e79c739e67a2754580bd77cc03b
parent7fcb7489401bd3f4d2fb84b758388d06a703b5be (diff)
Refactor cli
-rw-r--r--rag/cli.py93
-rw-r--r--rag/drop.py25
2 files changed, 48 insertions, 70 deletions
diff --git a/rag/cli.py b/rag/cli.py
index b210808..e8fdd86 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import Optional
import click
from dotenv import load_dotenv
@@ -26,7 +25,22 @@ def configure_logging(verbose: int):
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
-def upload(directory: str):
+@click.group()
+def cli():
+ pass
+
+
+@click.command()
+@click.option(
+ "-d",
+ "--directory",
+ help="The full path to the root directory containing pdfs to upload",
+ type=click.Path(exists=True),
+ default=None,
+)
+@click.option("-v", "--verbose", count=True)
+def upload(directory: str, verbose: int):
+ configure_logging(verbose)
log.info(f"Uploading pfs found in directory {directory}...")
retriever = Retriever()
pdfs = Path(directory).glob("**/*.pdf")
@@ -34,7 +48,24 @@ def upload(directory: str):
retriever.add_pdf(path=path)
-def rag(model: str, query: str):
+@click.command()
+@click.option(
+ "-q",
+ "--query",
+ prompt_required=False,
+ help="The query for rag",
+ prompt="Enter your query",
+)
+@click.option(
+ "-m",
+ "--model",
+ type=click.Choice(["local", "cohere"], case_sensitive=False),
+ default="local",
+ help="Generator and rerank model",
+)
+@click.option("-v", "--verbose", count=True)
+def rag(query: str, model: str, verbose: int):
+ configure_logging(verbose)
retriever = Retriever()
generator = get_generator(model)
reranker = get_reranker(model)
@@ -53,49 +84,21 @@ def rag(model: str, query: str):
@click.command()
-@click.option(
- "-q",
- "--query",
- prompt_required=False,
- help="The query for rag",
- prompt="Enter your query",
-)
-@click.option(
- "-m",
- "--model",
- type=click.Choice(["local", "cohere"], case_sensitive=False),
- default="local",
- help="Generator and rerank model",
-)
-@click.option(
- "-d",
- "--directory",
- help="The full path to the root directory containing pdfs to upload",
- type=click.Path(exists=True),
- default=None,
-)
-@click.option(
- "-q",
- "--query",
- prompt_required=False,
- help="The query for rag",
- prompt="Enter your query",
-)
-@click.option("-v", "--verbose", count=True)
-def main(
- query: Optional[str],
- generator: str,
- directory: Optional[str],
- verbose: int,
-):
- configure_logging(verbose)
- if directory:
- upload(directory)
- if query:
- rag(generator, query)
- # TODO: maybe add override for models
+@click.confirmation_option(prompt="Are you sure you want to drop the db?")
+def drop():
+ log.debug("Deleting all data...")
+ retriever = Retriever()
+ doc_db = retriever.doc_db
+ doc_db.delete_all()
+ vec_db = retriever.vec_db
+ vec_db.delete_collection()
+
+cli.add_command(configure_logging)
+cli.add_command(rag)
+cli.add_command(upload)
+cli.add_command(drop)
if __name__ == "__main__":
load_dotenv()
- main()
+ cli()
diff --git a/rag/drop.py b/rag/drop.py
deleted file mode 100644
index 5ff5983..0000000
--- a/rag/drop.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import click
-from dotenv import load_dotenv
-from loguru import logger as log
-
-from rag.retriever.retriever import Retriever
-
-
-def drop():
- log.debug("Deleting all data...")
- retriever = Retriever()
- doc_db = retriever.doc_db
- doc_db.delete_all()
- vec_db = retriever.vec_db
- vec_db.delete_collection()
-
-
-@click.command()
-@click.confirmation_option(prompt="Are you sure you want to drop the db?")
-def main():
- drop()
-
-
-if __name__ == "__main__":
- load_dotenv()
- main()