summaryrefslogtreecommitdiff
path: root/rag/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/cli.py')
-rw-r--r--rag/cli.py91
1 files changed, 67 insertions, 24 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 7e7e88a..a037cb4 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -1,33 +1,76 @@
from pathlib import Path
+import click
from dotenv import load_dotenv
+from loguru import logger as log
-from rag.generator import MODELS, get_generator
+from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever
-if __name__ == "__main__":
- load_dotenv()
+
+def upload(directory: str):
+ log.info(f"Uploading pfs found in directory {directory}...")
+ retriever = Retriever()
+ pdfs = Path(directory).glob("**/*.pdf")
+ for path in tqdm(list(pdfs)):
+ retriever.add_pdf(path=path)
+
+
+def rag(generator: str, query: str, limit):
retriever = Retriever()
+ generator = get_generator(generator)
+ documents = retriever.retrieve(query, limit=limit)
+ prompt = Prompt(query, documents)
+ print("Answer: ")
+ for chunk in generator.generate(prompt):
+ print(chunk, end="", flush=True)
- print("\n\nRetrieval Augmented Generation\n")
- model = input(f"Enter model ({MODELS}):")
-
- while True:
- choice = input("1. Add pdf from path\n2. Enter a query\n")
- match choice:
- case "1":
- path = input("Enter the path to the pdf: ")
- path = Path(path)
- retriever.add_pdf(path=path)
- case "2":
- query = input("Enter your query: ")
- if query:
- generator = get_generator(model)
- documents = retriever.retrieve(query)
- prompt = Prompt(query, documents)
- print("Answer: \n")
- for chunk in generator.generate(prompt):
- print(chunk, end="", flush=True)
- case _:
- print("Invalid option!")
+ print(f"\n\n")
+ for i, doc in enumerate(documents):
+ print(f"### Document {i}")
+ print(f"**Title: {doc.title}**")
+ print(doc.text)
+ print("---")
+
+
+@click.option(
+ "-q",
+ "--query",
+ help="The query for rag",
+ prompt="Enter your query",
+)
+@click.option(
+ "-g",
+ "--generator",
+ type=click.Choice(["ollama", "cohere"], case_sensitive=False),
+ default="ollama",
+ help="Generator client",
+)
+@click.option(
+ "-l",
+ "--limit",
+ type=click.IntRange(1, 20, clamp=True),
+ default=5,
+ help="Max number of documents used in grouding",
+)
+@click.command()
+@click.option(
+ "-d",
+ "--directory",
+ help="The full path to the root directory containing pdfs to upload",
+ type=click.Path(exists=True),
+ default=None,
+)
+def main(query: str, generator: str, limit: int, directory: str):
+ if query:
+ rag(generator, query, limit)
+ elif directory:
+ upload(directory)
+ # TODO: truncate databases
+ # TODO: maybe add override for models
+
+
+if __name__ == "__main__":
+ load_dotenv()
+ main()