diff options
Diffstat (limited to 'rag/cli.py')
| -rw-r--r-- | rag/cli.py | 91 | 
1 files changed, 67 insertions, 24 deletions
@@ -1,33 +1,76 @@  from pathlib import Path +import click  from dotenv import load_dotenv +from loguru import logger as log -from rag.generator import MODELS, get_generator +from rag.generator import get_generator  from rag.generator.prompt import Prompt  from rag.retriever.retriever import Retriever -if __name__ == "__main__": -    load_dotenv() + +def upload(directory: str): +    log.info(f"Uploading pfs found in directory {directory}...") +    retriever = Retriever() +    pdfs = Path(directory).glob("**/*.pdf") +    for path in tqdm(list(pdfs)): +        retriever.add_pdf(path=path) + + +def rag(generator: str, query: str, limit):      retriever = Retriever() +    generator = get_generator(generator) +    documents = retriever.retrieve(query, limit=limit) +    prompt = Prompt(query, documents) +    print("Answer: ") +    for chunk in generator.generate(prompt): +        print(chunk, end="", flush=True) -    print("\n\nRetrieval Augmented Generation\n") -    model = input(f"Enter model ({MODELS}):") - -    while True: -        choice = input("1. Add pdf from path\n2. Enter a query\n") -        match choice: -            case "1": -                path = input("Enter the path to the pdf: ") -                path = Path(path) -                retriever.add_pdf(path=path) -            case "2": -                query = input("Enter your query: ") -                if query: -                    generator = get_generator(model) -                    documents = retriever.retrieve(query) -                    prompt = Prompt(query, documents) -                    print("Answer: \n") -                    for chunk in generator.generate(prompt): -                        print(chunk, end="", flush=True) -            case _: -                print("Invalid option!") +    print(f"\n\n") +    for i, doc in enumerate(documents): +        print(f"### Document {i}") +        print(f"**Title: {doc.title}**") +        print(doc.text) +        print("---") + + +@click.option( +    "-q", +    "--query", +    help="The query for rag", +    prompt="Enter your query", +) +@click.option( +    "-g", +    "--generator", +    type=click.Choice(["ollama", "cohere"], case_sensitive=False), +    default="ollama", +    help="Generator client", +) +@click.option( +    "-l", +    "--limit", +    type=click.IntRange(1, 20, clamp=True), +    default=5, +    help="Max number of documents used in grouding", +) +@click.command() +@click.option( +    "-d", +    "--directory", +    help="The full path to the root directory containing pdfs to upload", +    type=click.Path(exists=True), +    default=None, +) +def main(query: str, generator: str, limit: int, directory: str): +    if query: +        rag(generator, query, limit) +    elif directory: +        upload(directory) +    # TODO: truncate databases +    # TODO: maybe add override for models + + +if __name__ == "__main__": +    load_dotenv() +    main()  |