summaryrefslogtreecommitdiff
path: root/rag/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/cli.py')
-rw-r--r--rag/cli.py32
1 files changed, 13 insertions, 19 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 932e2a9..b210808 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
from rag.generator import get_generator
from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
@@ -33,11 +34,12 @@ def upload(directory: str):
retriever.add_pdf(path=path)
-def rag(generator: str, query: str, limit):
+def rag(model: str, query: str):
retriever = Retriever()
- generator = get_generator(generator)
- documents = retriever.retrieve(query, limit=limit)
- prompt = generator.rerank(Prompt(query, documents))
+ generator = get_generator(model)
+ reranker = get_reranker(model)
+ documents = retriever.retrieve(query)
+ prompt = reranker.rerank(Prompt(query, documents))
print("Answer: ")
for chunk in generator.generate(prompt):
print(chunk, end="", flush=True)
@@ -50,6 +52,7 @@ def rag(generator: str, query: str, limit):
print("---")
+@click.command()
@click.option(
"-q",
"--query",
@@ -58,20 +61,12 @@ def rag(generator: str, query: str, limit):
prompt="Enter your query",
)
@click.option(
- "-g",
- "--generator",
- type=click.Choice(["ollama", "cohere"], case_sensitive=False),
- default="ollama",
- help="Generator client",
-)
-@click.option(
- "-l",
- "--limit",
- type=click.IntRange(1, 20, clamp=True),
- default=5,
- help="Max number of documents used in grouding",
+ "-m",
+ "--model",
+ type=click.Choice(["local", "cohere"], case_sensitive=False),
+ default="local",
+ help="Generator and rerank model",
)
-@click.command()
@click.option(
"-d",
"--directory",
@@ -90,7 +85,6 @@ def rag(generator: str, query: str, limit):
def main(
query: Optional[str],
generator: str,
- limit: int,
directory: Optional[str],
verbose: int,
):
@@ -98,7 +92,7 @@ def main(
if directory:
upload(directory)
if query:
- rag(generator, query, limit)
+ rag(generator, query)
# TODO: maybe add override for models