From bd87b7fed75d5e504eb40c6616c2f1e1e56a0451 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 14 Apr 2024 23:14:21 +0200 Subject: Refactor cli --- rag/cli.py | 91 +++++++++++++++++++++++++++++++++++++++++++---------------- rag/upload.py | 30 -------------------- 2 files changed, 67 insertions(+), 54 deletions(-) delete mode 100644 rag/upload.py (limited to 'rag') diff --git a/rag/cli.py b/rag/cli.py index 7e7e88a..a037cb4 100644 --- a/rag/cli.py +++ b/rag/cli.py @@ -1,33 +1,76 @@ from pathlib import Path +import click from dotenv import load_dotenv +from loguru import logger as log -from rag.generator import MODELS, get_generator +from rag.generator import get_generator from rag.generator.prompt import Prompt from rag.retriever.retriever import Retriever -if __name__ == "__main__": - load_dotenv() + +def upload(directory: str): + log.info(f"Uploading pfs found in directory {directory}...") + retriever = Retriever() + pdfs = Path(directory).glob("**/*.pdf") + for path in tqdm(list(pdfs)): + retriever.add_pdf(path=path) + + +def rag(generator: str, query: str, limit): retriever = Retriever() + generator = get_generator(generator) + documents = retriever.retrieve(query, limit=limit) + prompt = Prompt(query, documents) + print("Answer: ") + for chunk in generator.generate(prompt): + print(chunk, end="", flush=True) - print("\n\nRetrieval Augmented Generation\n") - model = input(f"Enter model ({MODELS}):") - - while True: - choice = input("1. Add pdf from path\n2. Enter a query\n") - match choice: - case "1": - path = input("Enter the path to the pdf: ") - path = Path(path) - retriever.add_pdf(path=path) - case "2": - query = input("Enter your query: ") - if query: - generator = get_generator(model) - documents = retriever.retrieve(query) - prompt = Prompt(query, documents) - print("Answer: \n") - for chunk in generator.generate(prompt): - print(chunk, end="", flush=True) - case _: - print("Invalid option!") + print(f"\n\n") + for i, doc in enumerate(documents): + print(f"### Document {i}") + print(f"**Title: {doc.title}**") + print(doc.text) + print("---") + + +@click.option( + "-q", + "--query", + help="The query for rag", + prompt="Enter your query", +) +@click.option( + "-g", + "--generator", + type=click.Choice(["ollama", "cohere"], case_sensitive=False), + default="ollama", + help="Generator client", +) +@click.option( + "-l", + "--limit", + type=click.IntRange(1, 20, clamp=True), + default=5, + help="Max number of documents used in grouding", +) +@click.command() +@click.option( + "-d", + "--directory", + help="The full path to the root directory containing pdfs to upload", + type=click.Path(exists=True), + default=None, +) +def main(query: str, generator: str, limit: int, directory: str): + if query: + rag(generator, query, limit) + elif directory: + upload(directory) + # TODO: truncate databases + # TODO: maybe add override for models + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/rag/upload.py b/rag/upload.py deleted file mode 100644 index 8567142..0000000 --- a/rag/upload.py +++ /dev/null @@ -1,30 +0,0 @@ -from pathlib import Path - -import click -from dotenv import load_dotenv -from loguru import logger as log -from tqdm import tqdm - -from rag.retriever.retriever import Retriever - - -@click.command() -@click.option( - "-d", - "--directory", - help="The full path to the root directory containing pdfs to upload", - type=click.Path(exists=True), -) -def main(directory: str): - log.info(f"Uploading pfs found in directory {directory}...") - retriever = Retriever() - pdfs = Path(directory).glob("**/*.pdf") - for path in tqdm(list(pdfs)): - retriever.add_pdf(path=path) - - -if __name__ == "__main__": - log.remove() - log.add(lambda msg: tqdm.write(msg, end=""), colorize=True) - load_dotenv() - main() -- cgit v1.2.3-70-g09d2