summaryrefslogtreecommitdiff
path: root/rag/cli.py
blob: 7e7e88a989e88ae56ba6cc92b5f8f63020235cd5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from pathlib import Path

from dotenv import load_dotenv

from rag.generator import MODELS, get_generator
from rag.generator.prompt import Prompt
from rag.retriever.retriever import Retriever

if __name__ == "__main__":
    load_dotenv()
    retriever = Retriever()

    print("\n\nRetrieval Augmented Generation\n")
    model = input(f"Enter model ({MODELS}):")

    while True:
        choice = input("1. Add pdf from path\n2. Enter a query\n")
        match choice:
            case "1":
                path = input("Enter the path to the pdf: ")
                path = Path(path)
                retriever.add_pdf(path=path)
            case "2":
                query = input("Enter your query: ")
                if query:
                    generator = get_generator(model)
                    documents = retriever.retrieve(query)
                    prompt = Prompt(query, documents)
                    print("Answer: \n")
                    for chunk in generator.generate(prompt):
                        print(chunk, end="", flush=True)
            case _:
                print("Invalid option!")