diff options
Diffstat (limited to 'rag/cli.py')
-rw-r--r-- | rag/cli.py | 91 |
1 files changed, 67 insertions, 24 deletions
"""Command-line interface for the retrieval augmented generation (RAG) tool.

Two modes are exposed through a single click command:

* ``--directory`` uploads every PDF found below a directory to the retriever.
* ``--query`` retrieves documents for the query and streams a generated answer.
"""

from pathlib import Path

import click
from dotenv import load_dotenv
from loguru import logger as log

from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever


def upload(directory: str) -> None:
    """Add every PDF found (recursively) below ``directory`` to the retriever.

    Args:
        directory: Root directory that is searched with the ``**/*.pdf`` glob.
    """
    log.info(f"Uploading pdfs found in directory {directory}...")
    retriever = Retriever()
    # Materialise the glob so the progress bar knows the total up front.
    pdfs = list(Path(directory).glob("**/*.pdf"))
    if not pdfs:
        log.warning(f"No pdfs found in directory {directory}")
        return
    # click is already a dependency of this module, so its progress bar is
    # used here; the previous code called tqdm without ever importing it.
    with click.progressbar(pdfs, label="Uploading pdfs") as progress:
        for path in progress:
            retriever.add_pdf(path=path)


def rag(generator: str, query: str, limit: int) -> None:
    """Answer ``query`` and print the answer followed by the documents used.

    Args:
        generator: Name of the generator client, passed to ``get_generator``.
        query: The question to answer.
        limit: Forwarded to ``Retriever.retrieve`` as its ``limit`` argument.
    """
    retriever = Retriever()
    client = get_generator(generator)
    documents = retriever.retrieve(query, limit=limit)
    prompt = Prompt(query, documents)

    print("Answer: ")
    # Print each chunk as soon as it is yielded so the answer streams.
    for chunk in client.generate(prompt):
        print(chunk, end="", flush=True)

    print("\n\n")
    for i, doc in enumerate(documents):
        print(f"### Document {i}")
        print(f"**Title: {doc.title}**")
        print(doc.text)
        print("---")


@click.command()
@click.option(
    "-q",
    "--query",
    default=None,
    help="The query for rag",
)
@click.option(
    "-g",
    "--generator",
    type=click.Choice(["ollama", "cohere"], case_sensitive=False),
    default="ollama",
    help="Generator client",
)
@click.option(
    "-l",
    "--limit",
    type=click.IntRange(1, 20, clamp=True),
    default=5,
    help="Max number of documents used in grounding",
)
@click.option(
    "-d",
    "--directory",
    help="The full path to the root directory containing pdfs to upload",
    type=click.Path(exists=True),
    default=None,
)
def main(query: str | None, generator: str, limit: int, directory: str | None):
    """Run a RAG query, or upload a directory of pdfs when only -d is given."""
    # Prompt only when the user supplied neither mode. Declaring the prompt on
    # the --query option itself made click always ask for a query, which left
    # the upload branch below unreachable.
    if not query and not directory:
        query = click.prompt("Enter your query")

    if query:
        rag(generator, query, limit)
    elif directory:
        upload(directory)
    # TODO: truncate databases
    # TODO: maybe add override for models


if __name__ == "__main__":
    load_dotenv()
    main()