summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-08-05 00:37:21 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-08-05 00:37:21 +0200
commit5531d8147e52324a16c977f385715f934af5f246 (patch)
tree8688c70a4cfc1ee617c9533a401530bd15556bf9 /rag
parent5142aaaa356549ba7e7e9cdacf365326191831ac (diff)
Fix broken stuff
Diffstat (limited to 'rag')
-rw-r--r--rag/__main__.py6
-rw-r--r--rag/cli.py43
-rw-r--r--rag/generator/prompt.py10
-rw-r--r--rag/retriever/rerank/local.py8
4 files changed, 37 insertions, 30 deletions
diff --git a/rag/__main__.py b/rag/__main__.py
new file mode 100644
index 0000000..be85a2e
--- /dev/null
+++ b/rag/__main__.py
@@ -0,0 +1,6 @@
+from dotenv import load_dotenv
+
+from rag.cli import cli
+
+load_dotenv()
+cli()
diff --git a/rag/cli.py b/rag/cli.py
index 070427d..b047255 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -5,10 +5,8 @@ from dotenv import load_dotenv
from loguru import logger as log
from tqdm import tqdm
-from rag.generator import get_generator
from rag.generator.prompt import Prompt
from rag.model import Rag
-from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
@@ -51,13 +49,6 @@ def upload(directory: str, verbose: int):
@click.command()
@click.option(
- "-q",
- "--query",
- prompt_required=False,
- help="The query for rag",
- prompt="Enter your query",
-)
-@click.option(
"-c",
"--client",
type=click.Choice(["local", "cohere"], case_sensitive=False),
@@ -65,21 +56,29 @@ def upload(directory: str, verbose: int):
help="Generator and rerank model",
)
@click.option("-v", "--verbose", count=True)
-def rag(query: str, client: str, verbose: int):
+def rag(client: str, verbose: int):
configure_logging(verbose)
rag = Rag(client)
- documents = rag.retrieve(query)
- prompt = Prompt(query, documents, client)
- print("Answer: ")
- for chunk in rag.generate(prompt):
- print(chunk, end="", flush=True)
-
- print("\n\n")
- for i, doc in enumerate(prompt.documents):
- print(f"### Document {i}")
- print(f"**Title: {doc.title}**")
- print(doc.text)
- print("---")
+ while True:
+ query = input("Query: ")
+ documents = rag.retrieve(query)
+ prompt = Prompt(query, documents, client)
+ print("Answer: ")
+ response = ""
+ for chunk in rag.generate(prompt):
+ print(chunk, end="", flush=True)
+ response += chunk
+
+ rag.add_message(rag.bot, response)
+
+ show_context = input("Display context? [y/n] ").lower() == "y"
+ print("\n\n")
+ if show_context:
+ for i, doc in enumerate(prompt.documents):
+ print(f"### Document {i}")
+ print(f"**Title: {doc.title}**")
+ print(doc.text)
+ print("---")
@click.command()
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index cedf610..75966e8 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -29,11 +29,13 @@ class Prompt:
return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
else:
return (
- "Context information is below.\n"
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n"
+ "Using the information contained in the context, give a comprehensive answer to the question.\n"
+ "If the answer cannot be deduced from the context, do not give an answer.\n\n"
+ "Context:\n"
"---\n"
f"{self.__context(self.documents)}\n\n"
"---\n"
- f"{ANSWER_INSTRUCTION}"
- f"Query: {self.query.strip()}\n\n"
- "Answer:"
+ f"Question: {self.query}<|eot_id|>\n"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index fd42c2c..231d50a 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,16 +2,16 @@ import os
from typing import List
from loguru import logger as log
-from rag.message import Message
-from rag.retriever.vector import Document
from sentence_transformers import CrossEncoder
+from rag.message import Message
from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
- self.model = CrossEncoder(os.environ["RERANK_MODEL"])
+ self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
@@ -33,7 +33,7 @@ class Reranker(metaclass=AbstractReranker):
def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
results = self.model.rank(
query=query,
- documents=[m.message for m in messages],
+ documents=[m.content for m in messages],
return_documents=False,
top_k=self.top_k,
)