diff options
Diffstat (limited to 'rag/generator')
| -rw-r--r-- | rag/generator/__init__.py | 6 | ||||
| -rw-r--r-- | rag/generator/abstract.py | 11 | ||||
| -rw-r--r-- | rag/generator/cohere.py | 12 | ||||
| -rw-r--r-- | rag/generator/ollama.py | 11 | ||||
| -rw-r--r-- | rag/generator/prompt.py | 10 | 
5 files changed, 21 insertions, 29 deletions
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py index a776231..770db16 100644 --- a/rag/generator/__init__.py +++ b/rag/generator/__init__.py @@ -1,8 +1,8 @@  from typing import Type -from .abstract import AbstractGenerator -from .cohere import Cohere -from .ollama import Ollama +from rag.generator.abstract import AbstractGenerator +from rag.generator.cohere import Cohere +from rag.generator.ollama import Ollama  MODELS = ["local", "cohere"] diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py index 995e937..3ce997e 100644 --- a/rag/generator/abstract.py +++ b/rag/generator/abstract.py @@ -1,9 +1,8 @@  from abc import abstractmethod -from typing import Any, Generator +from typing import Any, Generator, List -from rag.rag import Message - -from .prompt import Prompt +from rag.message import Message +from rag.retriever.vector import Document  class AbstractGenerator(type): @@ -16,7 +15,5 @@ class AbstractGenerator(type):          return cls._instances[cls]      @abstractmethod -    def generate( -        self, prompt: Prompt, messages: List[Message] -    ) -> Generator[Any, Any, Any]: +    def generate(self, messages: List[Message], documents: List[Document]) -> Generator[Any, Any, Any]:          pass diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index f30fe69..575452f 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -5,10 +5,10 @@ from typing import Any, Generator, List  import cohere  from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document  from .abstract import AbstractGenerator -from .prompt import Prompt  class Cohere(metaclass=AbstractGenerator): @@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator):          self.client = cohere.Client(os.environ["COHERE_API_KEY"])      def generate( -        self, prompt: Prompt, messages: List[Message] +        self, messages: List[Message], documents: List[Document]      ) -> Generator[Any, Any, Any]:          log.debug("Generating answer from cohere...")          for event in self.client.chat_stream( -            message=prompt.to_str(), -            documents=[asdict(d) for d in prompt.documents], -            chat_history=[m.as_dict() for m in messages], +            message=messages[-1].content, +            documents=[asdict(d) for d in documents], +            chat_history=[m.as_dict() for m in messages[:-1]],              prompt_truncation="AUTO",          ):              if event.event_type == "text-generation": diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py index ff5402b..84563bb 100644 --- a/rag/generator/ollama.py +++ b/rag/generator/ollama.py @@ -4,10 +4,10 @@ from typing import Any, Generator, List  import ollama  from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document  from .abstract import AbstractGenerator -from .prompt import Prompt  class Ollama(metaclass=AbstractGenerator): @@ -16,12 +16,9 @@ class Ollama(metaclass=AbstractGenerator):          log.debug(f"Using {self.model} for generator...")      def generate( -        self, prompt: Prompt, messages: List[Message] +        self, messages: List[Message], documents: List[Document]      ) -> Generator[Any, Any, Any]:          log.debug("Generating answer with ollama...") -        messages = messages.append( -            Message(role="user", content=prompt.to_str(), client="ollama") -        )          messages = [m.as_dict() for m in messages]          for chunk in ollama.chat(model=self.model, messages=messages, stream=True): -            yield chunk["response"] +            yield chunk["message"]["content"] diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py index 4840fdc..cedf610 100644 --- a/rag/generator/prompt.py +++ b/rag/generator/prompt.py @@ -15,7 +15,7 @@ ANSWER_INSTRUCTION = (  class Prompt:      query: str      documents: List[Document] -    generator_model: str +    client: str      def __context(self, documents: List[Document]) -> str:          results = [ @@ -25,17 +25,15 @@ class Prompt:          return "\n".join(results)      def to_str(self) -> str: -        if self.generator_model == "cohere": +        if self.client == "cohere":              return f"{self.query}\n\n{ANSWER_INSTRUCTION}"          else:              return (                  "Context information is below.\n" -                "---------------------\n" +                "---\n"                  f"{self.__context(self.documents)}\n\n" -                "---------------------\n" +                "---\n"                  f"{ANSWER_INSTRUCTION}" -                "Do not attempt to answer the query without relevant context and do not use" -                " prior knowledge or training data!\n"                  f"Query: {self.query.strip()}\n\n"                  "Answer:"              )  |