Diffstat (limited to 'rag/generator')
-rw-r--r--   rag/generator/abstract.py |  6
-rw-r--r--   rag/generator/cohere.py   | 13
-rw-r--r--   rag/generator/ollama.py   | 36
-rw-r--r--   rag/generator/prompt.py   | 24
4 files changed, 46 insertions, 33 deletions
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..995e937 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,6 +1,8 @@
 from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Generator, List
 
+from rag.rag import Message
+
 from .prompt import Prompt
@@ -14,5 +16,7 @@ class AbstractGenerator(type):
         return cls._instances[cls]
 
     @abstractmethod
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(
+        self, prompt: Prompt, messages: List[Message]
+    ) -> Generator[Any, Any, Any]:
         pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index fb0cc5b..f30fe69 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,12 +1,14 @@
 import os
 from dataclasses import asdict
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Generator, List
 
 import cohere
 from loguru import logger as log
 
+from rag.rag import Message
+
 from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
 
 
 class Cohere(metaclass=AbstractGenerator):
@@ -14,14 +16,13 @@ class Cohere(metaclass=AbstractGenerator):
         self.client = cohere.Client(os.environ["COHERE_API_KEY"])
 
     def generate(
-        self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+        self, prompt: Prompt, messages: List[Message]
     ) -> Generator[Any, Any, Any]:
         log.debug("Generating answer from cohere...")
-        query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
         for event in self.client.chat_stream(
-            message=query,
+            message=prompt.to_str(),
             documents=[asdict(d) for d in prompt.documents],
-            chat_history=history,
+            chat_history=[m.as_dict() for m in messages],
             prompt_truncation="AUTO",
         ):
             if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 9bf551a..ff5402b 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -4,10 +4,10 @@ from typing import Any, Generator, List
 import ollama
 from loguru import logger as log
 
-from rag.retriever.vector import Document
+from rag.rag import Message
 
 from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
 
 
 class Ollama(metaclass=AbstractGenerator):
@@ -15,29 +15,13 @@ class Ollama(metaclass=AbstractGenerator):
         self.model = os.environ["GENERATOR_MODEL"]
         log.debug(f"Using {self.model} for generator...")
 
-    def __context(self, documents: List[Document]) -> str:
-        results = [
-            f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
-            for i, doc in enumerate(documents)
-        ]
-        return "\n".join(results)
-
-    def __metaprompt(self, prompt: Prompt) -> str:
-        metaprompt = (
-            "Context information is below.\n"
-            "---------------------\n"
-            f"{self.__context(prompt.documents)}\n\n"
-            "---------------------\n"
-            f"{ANSWER_INSTRUCTION}"
-            "Do not attempt to answer the query without relevant context and do not use"
-            " prior knowledge or training data!\n"
-            f"Query: {prompt.query.strip()}\n\n"
-            "Answer:"
-        )
-        return metaprompt
-
-    def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
+    def generate(
+        self, prompt: Prompt, messages: List[Message]
+    ) -> Generator[Any, Any, Any]:
         log.debug("Generating answer with ollama...")
-        metaprompt = self.__metaprompt(prompt)
-        for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True):
-            yield chunk["response"]
+        messages.append(
+            Message(role="user", content=prompt.to_str(), client="ollama")
+        )
+        history = [m.as_dict() for m in messages]
+        for chunk in ollama.chat(model=self.model, messages=history, stream=True):
+            yield chunk["message"]["content"]
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 6523842..4840fdc 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -15,3 +15,27 @@ ANSWER_INSTRUCTION = (
 class Prompt:
     query: str
     documents: List[Document]
+    generator_model: str
+
+    def __context(self, documents: List[Document]) -> str:
+        results = [
+            f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
+            for i, doc in enumerate(documents)
+        ]
+        return "\n".join(results)
+
+    def to_str(self) -> str:
+        if self.generator_model == "cohere":
+            return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
+        else:
+            return (
+                "Context information is below.\n"
+                "---------------------\n"
+                f"{self.__context(self.documents)}\n\n"
+                "---------------------\n"
+                f"{ANSWER_INSTRUCTION}"
+                "Do not attempt to answer the query without relevant context and do not use"
+                " prior knowledge or training data!\n"
+                f"Query: {self.query.strip()}\n\n"
+                "Answer:"
+            )
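
The new generate signatures depend on a Message type from rag.rag that this diff does not show. Below is a minimal sketch of what the generators appear to assume; the field names, the as_dict() shape, and the per-client handling are inferred from the call sites above, not confirmed by this commit:

    from dataclasses import dataclass
    from typing import Dict


    @dataclass
    class Message:
        role: str     # "user" or "assistant"
        content: str
        client: str   # backend the message belongs to, e.g. "ollama" or "cohere"

        def as_dict(self) -> Dict[str, str]:
            # ollama.chat() expects {"role": ..., "content": ...}; Cohere's
            # chat_history uses {"role": ..., "message": ...}, so a real
            # implementation would likely branch on self.client here.
            return {"role": self.role, "content": self.content}

With generator_model now stored on Prompt, a caller builds the prompt once and passes the running history explicitly (docs, history, and the generator instance are hypothetical placeholders):

    prompt = Prompt(query="What is RAG?", documents=docs, generator_model="ollama")
    for token in generator.generate(prompt, history):
        print(token, end="", flush=True)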