summaryrefslogtreecommitdiff
path: root/rag/generator
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:41:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:41:55 +0200
commit3f447bff69c20109474c455f1ad52bd547ab49e9 (patch)
tree66695ca0e3423e2973c5b24ec1ae7096019b5dd0 /rag/generator
parentc05eae81f9aaa0a764203446ec54d7dd7cbeb66f (diff)
Update
Diffstat (limited to 'rag/generator')
-rw-r--r--rag/generator/__init__.py1
-rw-r--r--rag/generator/abstract.py4
2 files changed, 3 insertions, 2 deletions
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
index 7da603c..541eff8 100644
--- a/rag/generator/__init__.py
+++ b/rag/generator/__init__.py
@@ -4,6 +4,7 @@ from .abstract import AbstractGenerator
from .ollama import Ollama
from .cohere import Cohere
+MODELS = ["ollama", "cohere"]
def get_generator(model: str) -> Type[AbstractGenerator]:
match model:
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index a53b5d8..5b336ea 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,11 +1,11 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
from typing import Any, Generator
from .prompt import Prompt
-class AbstractGenerator(ABC, type):
+class AbstractGenerator(type):
_instances = {}
def __call__(cls, *args, **kwargs):