summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-14 23:14:21 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-14 23:14:21 +0200
commit bd87b7fed75d5e504eb40c6616c2f1e1e56a0451 (patch)
tree f6e3b68c48f150c8b6e6acd33d6e760f334a456d
parent 3c59ed779e3b30ab4877ae94242c9df076df681a (diff)
Refactor cli
-rw-r--r-- rag/cli.py 91
-rw-r--r-- rag/upload.py 30
2 files changed, 67 insertions, 54 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 7e7e88a..a037cb4 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -1,33 +1,76 @@
from pathlib import Path
+import click
from dotenv import load_dotenv
+from loguru import logger as log
-from rag.generator import MODELS, get_generator
+from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever
-if __name__ == "__main__":
- load_dotenv()
+
+def upload(directory: str):
+    """Add every PDF found under ``directory`` to the retriever.
+
+    The directory is searched recursively with the ``**/*.pdf`` glob and
+    each match is handed to ``Retriever.add_pdf``; a tqdm progress bar is
+    shown over the list of matches.
+
+    Args:
+        directory: Root directory to search for PDF files.
+    """
+    # tqdm is not among this module's top-level imports, so bring it into
+    # scope here; without this the loop below fails with NameError.
+    from tqdm import tqdm
+
+    log.info(f"Uploading pdfs found in directory {directory}...")
+    retriever = Retriever()
+    # Materialise the glob so tqdm receives a sized list for its progress bar.
+    pdfs = list(Path(directory).glob("**/*.pdf"))
+    for path in tqdm(pdfs):
+        retriever.add_pdf(path=path)
+
+
+def rag(generator: str, query: str, limit):
+    """Answer ``query`` from retrieved documents and print the result.
+
+    Retrieves documents for the query, builds a ``Prompt`` from them, prints
+    the generator's output chunk by chunk as it is produced, and finally
+    prints the title and text of every retrieved document.
+
+    Args:
+        generator: Name handed to ``get_generator`` to select the client.
+        query: The question to answer.
+        limit: Forwarded to ``Retriever.retrieve`` as ``limit``.
+    """
     retriever = Retriever()
+    client = get_generator(generator)
+    sources = retriever.retrieve(query, limit=limit)
+    grounded_prompt = Prompt(query, sources)
+    print("Answer: ")
+    # Print each chunk immediately (no newline, flushed) as it is yielded.
+    for piece in client.generate(grounded_prompt):
+        print(piece, end="", flush=True)
- print("\n\nRetrieval Augmented Generation\n")
- model = input(f"Enter model ({MODELS}):")
-
- while True:
- choice = input("1. Add pdf from path\n2. Enter a query\n")
- match choice:
- case "1":
- path = input("Enter the path to the pdf: ")
- path = Path(path)
- retriever.add_pdf(path=path)
- case "2":
- query = input("Enter your query: ")
- if query:
- generator = get_generator(model)
- documents = retriever.retrieve(query)
- prompt = Prompt(query, documents)
- print("Answer: \n")
- for chunk in generator.generate(prompt):
- print(chunk, end="", flush=True)
- case _:
- print("Invalid option!")
+    print("\n\n")
+    # List the documents the answer was grounded on.
+    for index, source in enumerate(sources):
+        print(f"### Document {index}")
+        print(f"**Title: {source.title}**")
+        print(source.text)
+        print("---")
+
+
+# NOTE: --query deliberately has no ``prompt=``. An option-level prompt fires
+# during argument parsing whenever -q is missing, so a query was always
+# supplied and the --directory upload branch could not be reached in normal
+# use. The prompt is issued inside main() only when it is actually needed.
+@click.command()
+@click.option(
+    "-q",
+    "--query",
+    default=None,
+    help="The query for rag",
+)
+@click.option(
+    "-g",
+    "--generator",
+    type=click.Choice(["ollama", "cohere"], case_sensitive=False),
+    default="ollama",
+    help="Generator client",
+)
+@click.option(
+    "-l",
+    "--limit",
+    type=click.IntRange(1, 20, clamp=True),
+    default=5,
+    help="Max number of documents used in grounding",
+)
+@click.option(
+    "-d",
+    "--directory",
+    help="The full path to the root directory containing pdfs to upload",
+    type=click.Path(exists=True),
+    default=None,
+)
+def main(query: str, generator: str, limit: int, directory: str):
+    """Answer a query with rag, or upload pdfs when only a directory is given."""
+    # The docstring above doubles as the command's --help text, so it is
+    # kept to a single user-facing sentence.
+    #
+    # Dispatch rules:
+    #   * query given            -> run rag (takes precedence over directory)
+    #   * only directory given   -> upload the pdfs found under it
+    #   * neither given          -> ask for a query interactively, then run rag
+    if not query and not directory:
+        query = click.prompt("Enter your query")
+    if query:
+        rag(generator, query, limit)
+    elif directory:
+        upload(directory)
+    # TODO: truncate databases
+    # TODO: maybe add override for models
+
+
+if __name__ == "__main__":
+ # Script entry point: load_dotenv() runs before main() so whatever it sets
+ # up is in place when the CLI starts (presumably environment variables
+ # read from a .env file -- confirm against the dotenv docs).
+ load_dotenv()
+ main()
diff --git a/rag/upload.py b/rag/upload.py
deleted file mode 100644
index 8567142..0000000
--- a/rag/upload.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from pathlib import Path
-
-import click
-from dotenv import load_dotenv
-from loguru import logger as log
-from tqdm import tqdm
-
-from rag.retriever.retriever import Retriever
-
-
-@click.command()
-@click.option(
- "-d",
- "--directory",
- help="The full path to the root directory containing pdfs to upload",
- type=click.Path(exists=True),
-)
-def main(directory: str):
- log.info(f"Uploading pfs found in directory {directory}...")
- retriever = Retriever()
- pdfs = Path(directory).glob("**/*.pdf")
- for path in tqdm(list(pdfs)):
- retriever.add_pdf(path=path)
-
-
-if __name__ == "__main__":
- log.remove()
- log.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
- load_dotenv()
- main()