summary refs log tree commit diff
path: root/rag/generator/ollama.py
blob: ff5402bc6aa47229cfb78454be8dbaa4dda835c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
from typing import Any, Generator, List

import ollama
from loguru import logger as log

from rag.rag import Message

from .abstract import AbstractGenerator
from .prompt import Prompt


class Ollama(metaclass=AbstractGenerator):
    """Answer generator backed by an Ollama chat model.

    The model name is read from the ``GENERATOR_MODEL`` environment variable
    when the generator is constructed.
    """

    def __init__(self) -> None:
        # Raises KeyError if GENERATOR_MODEL is not set in the environment.
        self.model = os.environ["GENERATOR_MODEL"]
        log.debug(f"Using {self.model} for generator...")

    def generate(
        self, prompt: Prompt, messages: List[Message]
    ) -> Generator[Any, Any, Any]:
        """Stream an answer to ``prompt`` given the earlier conversation.

        Args:
            prompt: The prompt sent as the final user turn; it is rendered
                with ``prompt.to_str()``.
            messages: Earlier conversation turns, sent before the prompt.
                The caller's list is not modified.

        Yields:
            The text content of each streamed chat chunk, in arrival order.
        """
        log.debug("Generating answer with ollama...")
        # Build a new list instead of assigning the result of
        # ``messages.append(...)``: ``append`` returns None (which would make
        # the comprehension below fail) and would mutate the caller's list.
        conversation = [
            *messages,
            Message(role="user", content=prompt.to_str(), client="ollama"),
        ]
        payload = [m.as_dict() for m in conversation]
        for chunk in ollama.chat(model=self.model, messages=payload, stream=True):
            # Streaming chat chunks carry their text under message.content;
            # the top-level "response" key is only produced by ollama.generate.
            yield chunk["message"]["content"]