summaryrefslogtreecommitdiff
path: root/rag/generator/prompt.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/generator/prompt.py')
-rw-r--r--rag/generator/prompt.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 6523842..4840fdc 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -15,3 +15,27 @@ ANSWER_INSTRUCTION = (
class Prompt:
query: str
documents: List[Document]
+ generator_model: str
+
+ def __context(self, documents: List[Document]) -> str:
+ results = [
+ f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
+ for i, doc in enumerate(documents)
+ ]
+ return "\n".join(results)
+
+ def to_str(self) -> str:
+ if self.generator_model == "cohere":
+ return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
+ else:
+ return (
+ "Context information is below.\n"
+ "---------------------\n"
+ f"{self.__context(self.documents)}\n\n"
+ "---------------------\n"
+ f"{ANSWER_INSTRUCTION}"
+ "Do not attempt to answer the query without relevant context and do not use"
+ " prior knowledge or training data!\n"
+ f"Query: {self.query.strip()}\n\n"
+ "Answer:"
+ )