summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-29 00:53:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-05-29 00:53:39 +0200
commit716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch)
tree778da9011d21051006fc206ce0978f0fc114b77b /rag
parent2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff)
Wip memory
Diffstat (limited to 'rag')
-rw-r--r--rag/generator/cohere.py7
-rw-r--r--rag/generator/ollama.py4
-rw-r--r--rag/retriever/memory.py51
-rw-r--r--rag/retriever/rerank/local.py21
-rw-r--r--rag/ui.py14
5 files changed, 79 insertions, 18 deletions
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 28a87e7..fb0cc5b 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
import os
from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List, Optional
import cohere
from loguru import logger as log
@@ -13,12 +13,15 @@ class Cohere(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ def generate(
+ self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+ ) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
message=query,
documents=[asdict(d) for d in prompt.documents],
+ chat_history=history,
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 52521ca..9bf551a 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -36,8 +36,8 @@ class Ollama(metaclass=AbstractGenerator):
)
return metaprompt
def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
    """Stream an answer for *prompt* from the local Ollama model.

    Args:
        prompt: The rendered RAG prompt (query + documents).
        memory: Conversation memory; its ``append`` is expected to return the
            chat-history message list including the new user prompt —
            TODO(review): confirm the Memory API actually exposes ``append``.

    Yields:
        Text fragments of the streamed answer.

    Bug fix: ``ollama.chat`` streams dicts shaped ``{"message": {"content": ...}}``;
    the original indexed ``chunk["response"]``, which is the ``ollama.generate``
    schema and raises KeyError here.
    """
    log.debug("Generating answer with ollama...")
    metaprompt = self.__metaprompt(prompt)
    messages = memory.append(metaprompt)
    for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
        yield chunk["message"]["content"]
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
new file mode 100644
index 0000000..c4455ed
--- /dev/null
+++ b/rag/retriever/memory.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from typing import Dict, List
+
+
@dataclass
class Log:
    """One conversational exchange: the user prompt and the bot reply.

    Bug fixes: the original ``get`` omitted ``self`` and referenced the
    fields as bare names (NameError at call time), and the field
    annotations referenced ``Message`` before its definition later in this
    module (NameError at class-creation time) — now a quoted forward
    reference.
    """

    # Forward references as strings: Message is defined further down the module.
    user: "Message"
    bot: "Message"

    def get(self) -> tuple:
        """Return the (user, bot) message pair."""
        return (self.user, self.bot)
+
+
@dataclass
class Message:
    """A single chat message with its speaker role."""

    # Speaker role, e.g. "user" or "assistant".
    role: str
    # The message text itself.
    message: str

    def as_dict(self, model: str) -> Dict[str, str]:
        """Serialize this message into the given backend's chat-history schema.

        Cohere expects upper-cased roles under a ``message`` key; every other
        backend gets the role as-is under a ``content`` key.
        """
        if model != "cohere":
            return {"role": self.role, "content": self.message}
        cohere_role = "USER" if self.role == "user" else "CHATBOT"
        return {"role": cohere_role, "message": self.message}
+
+
class Memory:
    """Stores conversation history as (user, bot) exchange pairs.

    Bug fixes: ``get`` was annotated ``List[Log]`` while actually returning
    dicts, and called ``Message.as_dict()`` without its required ``model``
    argument (TypeError). Also adds ``append``, which ``Ollama.generate``
    calls (``memory.append(metaprompt)``) but the original never defined.
    """

    def __init__(self, reranker) -> None:
        # reranker: presumably exposes rank_memory(prompt, history); it is
        # only stored here — TODO(review): confirm against callers.
        self.history: List[Log] = []
        self.reranker = reranker
        self.user = "user"
        self.bot = "assistant"

    def add(self, prompt: str, response: str) -> None:
        """Record one completed exchange (user prompt + bot response)."""
        self.history.append(
            Log(
                user=Message(role=self.user, message=prompt),
                bot=Message(role=self.bot, message=response),
            )
        )

    def get(self, model: str = "ollama") -> List[Dict[str, str]]:
        """Flatten the history into per-message dicts for the given backend."""
        return [m.as_dict(model) for log in self.history for m in log.get()]

    def append(self, prompt: str, model: str = "ollama") -> List[Dict[str, str]]:
        """Return the flattened history plus *prompt* as a trailing user
        message — the shape ``ollama.chat`` expects for ``messages``.

        Does not mutate the stored history.
        """
        return self.get(model) + [{"role": self.user, "content": prompt}]

    def reset(self) -> None:
        """Drop all stored history."""
        self.history = []
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 75fedd8..8e94882 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,9 +1,11 @@
import os
+from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
from rag.generator.prompt import Prompt
+from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):
prompt.documents[r.get("corpus_id", 0)] for r in ranking
]
return prompt
+
def rank_memory(self, prompt: "Prompt", history: List["Log"]) -> List["Log"]:
    """Filter past exchanges, keeping only those whose bot replies the
    cross-encoder scores as relevant to the current query.

    An empty history is returned unchanged without touching the model.
    """
    if not history:
        return history
    scores = self.model.rank(
        query=prompt.query,
        documents=[entry.bot.message for entry in history],
        return_documents=False,
        top_k=self.top_k,
    )
    # Drop anything at or below the relevance threshold.
    relevant = [s for s in scores if s.get("score", 0.0) > self.relevance_threshold]
    log.debug(f"Reranking gave {len(relevant)} relevant messages of {len(history)}")
    return [history[s.get("corpus_id", 0)] for s in relevant]
diff --git a/rag/ui.py b/rag/ui.py
index 36e8c4c..ddb3d78 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -13,18 +13,6 @@ from rag.retriever.retriever import Retriever
from rag.retriever.vector import Document
-@dataclass
-class Message:
- role: str
- message: str
-
- def as_dict(self, model: str) -> Dict[str, str]:
- if model == "cohere":
- return {"role": self.role, "message": self.message}
- else:
- return {"role": self.role, "content": self.message}
-
-
def set_chat_users():
log.debug("Setting user and bot value")
ss = st.session_state
@@ -38,13 +26,11 @@ def load_retriever():
st.session_state.retriever = Retriever()
-# @st.cache_resource
def load_generator(model: str):
log.debug("Loading generator model")
st.session_state.generator = get_generator(model)
-# @st.cache_resource
def load_reranker(model: str):
log.debug("Loading reranker model")
st.session_state.reranker = get_reranker(model)