diff options
Diffstat (limited to 'rag/generator')
-rw-r--r-- | rag/generator/__init__.py | 6 | ||||
-rw-r--r-- | rag/generator/abstract.py | 11 | ||||
-rw-r--r-- | rag/generator/cohere.py | 12 | ||||
-rw-r--r-- | rag/generator/ollama.py | 11 | ||||
-rw-r--r-- | rag/generator/prompt.py | 10 |
5 files changed, 21 insertions, 29 deletions
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py index a776231..770db16 100644 --- a/rag/generator/__init__.py +++ b/rag/generator/__init__.py @@ -1,8 +1,8 @@ from typing import Type -from .abstract import AbstractGenerator -from .cohere import Cohere -from .ollama import Ollama +from rag.generator.abstract import AbstractGenerator +from rag.generator.cohere import Cohere +from rag.generator.ollama import Ollama MODELS = ["local", "cohere"] diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py index 995e937..3ce997e 100644 --- a/rag/generator/abstract.py +++ b/rag/generator/abstract.py @@ -1,9 +1,8 @@ from abc import abstractmethod -from typing import Any, Generator +from typing import Any, Generator, List -from rag.rag import Message - -from .prompt import Prompt +from rag.message import Message +from rag.retriever.vector import Document class AbstractGenerator(type): @@ -16,7 +15,5 @@ class AbstractGenerator(type): return cls._instances[cls] @abstractmethod - def generate( - self, prompt: Prompt, messages: List[Message] - ) -> Generator[Any, Any, Any]: + def generate(self, messages: List[Message], documents: List[Document]) -> Generator[Any, Any, Any]: pass diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index f30fe69..575452f 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -5,10 +5,10 @@ from typing import Any, Generator, List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document from .abstract import AbstractGenerator -from .prompt import Prompt class Cohere(metaclass=AbstractGenerator): @@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator): self.client = cohere.Client(os.environ["COHERE_API_KEY"]) def generate( - self, prompt: Prompt, messages: List[Message] + self, messages: List[Message], documents: List[Document] ) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") for event in self.client.chat_stream( - message=prompt.to_str(), - documents=[asdict(d) for d in prompt.documents], - chat_history=[m.as_dict() for m in messages], + message=messages[-1].content, + documents=[asdict(d) for d in documents], + chat_history=[m.as_dict() for m in messages[:-1]], prompt_truncation="AUTO", ): if event.event_type == "text-generation": diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py index ff5402b..84563bb 100644 --- a/rag/generator/ollama.py +++ b/rag/generator/ollama.py @@ -4,10 +4,10 @@ from typing import Any, Generator, List import ollama from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document from .abstract import AbstractGenerator -from .prompt import Prompt class Ollama(metaclass=AbstractGenerator): @@ -16,12 +16,9 @@ class Ollama(metaclass=AbstractGenerator): log.debug(f"Using {self.model} for generator...") def generate( - self, prompt: Prompt, messages: List[Message] + self, messages: List[Message], documents: List[Document] ) -> Generator[Any, Any, Any]: log.debug("Generating answer with ollama...") - messages = messages.append( - Message(role="user", content=prompt.to_str(), client="ollama") - ) messages = [m.as_dict() for m in messages] for chunk in ollama.chat(model=self.model, messages=messages, stream=True): - yield chunk["response"] + yield chunk["message"]["content"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py index 4840fdc..cedf610 100644 --- a/rag/generator/prompt.py +++ b/rag/generator/prompt.py @@ -15,7 +15,7 @@ ANSWER_INSTRUCTION = ( class Prompt: query: str documents: List[Document] - generator_model: str + client: str def __context(self, documents: List[Document]) -> str: results = [ @@ -25,17 +25,15 @@ class Prompt: return "\n".join(results) def to_str(self) -> str: - if self.generator_model == "cohere": + if self.client == "cohere": return f"{self.query}\n\n{ANSWER_INSTRUCTION}" else: return ( "Context information is below.\n" - "---------------------\n" + "---\n" f"{self.__context(self.documents)}\n\n" - "---------------------\n" + "---\n" f"{ANSWER_INSTRUCTION}" - "Do not attempt to answer the query without relevant context and do not use" - " prior knowledge or training data!\n" f"Query: {self.query.strip()}\n\n" "Answer:" ) |