summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-06-18 01:37:32 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-06-18 01:37:32 +0200
commitb1ff0c55422d7b0af2c379679b8721014ef36926 (patch)
tree52aa88b2a8a0bba07f968c6ae24c002ce2d44226
parentb8c6a78f70d84f3360461aa91864e8538569d450 (diff)
Wip rewrite
-rw-r--r--rag/generator/abstract.py6
-rw-r--r--rag/generator/cohere.py13
-rw-r--r--rag/generator/ollama.py36
-rw-r--r--rag/generator/prompt.py24
-rw-r--r--rag/rag.py69
-rw-r--r--rag/retriever/memory.py51
-rw-r--r--rag/retriever/rerank/abstract.py12
-rw-r--r--rag/retriever/rerank/cohere.py55
-rw-r--r--rag/retriever/rerank/local.py70
-rw-r--r--rag/ui.py23
10 files changed, 196 insertions, 163 deletions
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 1beacfb..995e937 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,6 +1,8 @@
from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Generator, List
+
+from rag.rag import Message
from .prompt import Prompt
@@ -14,5 +16,7 @@ class AbstractGenerator(type):
return cls._instances[cls]
@abstractmethod
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ def generate(
+ self, prompt: Prompt, messages: List[Message]
+ ) -> Generator[Any, Any, Any]:
pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index fb0cc5b..f30fe69 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,12 +1,14 @@
import os
from dataclasses import asdict
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Generator, List
import cohere
from loguru import logger as log
+from rag.rag import Message
+
from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
class Cohere(metaclass=AbstractGenerator):
@@ -14,14 +16,13 @@ class Cohere(metaclass=AbstractGenerator):
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
def generate(
- self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+ self, prompt: Prompt, messages: List[Message]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
- query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
- message=query,
+ message=prompt.to_str(),
documents=[asdict(d) for d in prompt.documents],
- chat_history=history,
+ chat_history=[m.as_dict() for m in messages],
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 9bf551a..ff5402b 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -4,10 +4,10 @@ from typing import Any, Generator, List
import ollama
from loguru import logger as log
-from rag.retriever.vector import Document
+from rag.rag import Message
from .abstract import AbstractGenerator
-from .prompt import ANSWER_INSTRUCTION, Prompt
+from .prompt import Prompt
class Ollama(metaclass=AbstractGenerator):
@@ -15,29 +15,13 @@ class Ollama(metaclass=AbstractGenerator):
self.model = os.environ["GENERATOR_MODEL"]
log.debug(f"Using {self.model} for generator...")
- def __context(self, documents: List[Document]) -> str:
- results = [
- f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
- for i, doc in enumerate(documents)
- ]
- return "\n".join(results)
-
- def __metaprompt(self, prompt: Prompt) -> str:
- metaprompt = (
- "Context information is below.\n"
- "---------------------\n"
- f"{self.__context(prompt.documents)}\n\n"
- "---------------------\n"
- f"{ANSWER_INSTRUCTION}"
- "Do not attempt to answer the query without relevant context and do not use"
- " prior knowledge or training data!\n"
- f"Query: {prompt.query.strip()}\n\n"
- "Answer:"
- )
- return metaprompt
-
- def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
+ def generate(
+ self, prompt: Prompt, messages: List[Message]
+ ) -> Generator[Any, Any, Any]:
log.debug("Generating answer with ollama...")
- metaprompt = self.__metaprompt(prompt)
- for chunk in ollama.chat(model=self.model, messages=memory.append(metaprompt), stream=True):
+        messages.append(
+ Message(role="user", content=prompt.to_str(), client="ollama")
+ )
+ messages = [m.as_dict() for m in messages]
+ for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
yield chunk["response"]
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 6523842..4840fdc 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -15,3 +15,27 @@ ANSWER_INSTRUCTION = (
class Prompt:
query: str
documents: List[Document]
+ generator_model: str
+
+ def __context(self, documents: List[Document]) -> str:
+ results = [
+ f"Document: {i}\ntitle: {doc.title}\ntext: {doc.text}"
+ for i, doc in enumerate(documents)
+ ]
+ return "\n".join(results)
+
+ def to_str(self) -> str:
+ if self.generator_model == "cohere":
+ return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
+ else:
+ return (
+ "Context information is below.\n"
+ "---------------------\n"
+ f"{self.__context(self.documents)}\n\n"
+ "---------------------\n"
+ f"{ANSWER_INSTRUCTION}"
+ "Do not attempt to answer the query without relevant context and do not use"
+ " prior knowledge or training data!\n"
+ f"Query: {self.query.strip()}\n\n"
+ "Answer:"
+ )
diff --git a/rag/rag.py b/rag/rag.py
new file mode 100644
index 0000000..1f6a176
--- /dev/null
+++ b/rag/rag.py
@@ -0,0 +1,69 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Generator, List
+
+from loguru import logger as log
+
+from rag.generator import get_generator
+from rag.generator.prompt import Prompt
+from rag.retriever.rerank import get_reranker
+from rag.retriever.retriever import Retriever
+from rag.retriever.vector import Document
+
+
+@dataclass
+class Message:
+ role: str
+ content: str
+ client: str
+
+ def as_dict(self) -> Dict[str, str]:
+ if self.client == "cohere":
+ return {"role": self.role, "message": self.content}
+ else:
+ return {"role": self.role, "content": self.content}
+
+
+class Rag:
+ def __init__(self, client: str) -> None:
+ self.messages: List[Message] = []
+ self.retriever = Retriever()
+ self.client = client
+ self.reranker = get_reranker(self.client)
+ self.generator = get_generator(self.client)
+ self.bot = "assistant" if self.client == "ollama" else "CHATBOT"
+ self.user = "user" if self.client == "ollama" else "USER"
+
+ def __set_roles(self):
+ self.bot = "assistant" if self.client == "ollama" else "CHATBOT"
+ self.user = "user" if self.client == "ollama" else "USER"
+
+ def set_client(self, client: str):
+ self.client = client
+ self.reranker = get_reranker(self.client)
+ self.generator = get_generator(self.client)
+ self.__set_roles()
+ self.__reset_messages()
+ log.debug(f"Swapped client to {self.client}")
+
+ def __reset_messages(self):
+ log.debug("Deleting messages...")
+ self.messages = []
+
+ def retrieve(self, query: str) -> List[Document]:
+ documents = self.retriever.retrieve(query)
+ log.info(f"Found {len(documents)} relevant documents")
+ return self.reranker.rerank_documents(query, documents)
+
+ def add_message(self, role: str, content: str):
+ self.messages.append(
+ Message(role=role, content=content, client=self.client)
+ )
+
+ def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ messages = self.reranker.rerank_messages(prompt.query, self.messages)
+ self.messages.append(
+ Message(
+ role=self.user, content=prompt.to_str(), client=self.client
+ )
+ )
+ return self.generator.generate(prompt, messages)
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
deleted file mode 100644
index c4455ed..0000000
--- a/rag/retriever/memory.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List
-
-
-@dataclass
-class Log:
- user: Message
- bot: Message
-
- def get():
- return (user, bot)
-
-
-@dataclass
-class Message:
- role: str
- message: str
-
- def as_dict(self, model: str) -> Dict[str, str]:
- if model == "cohere":
- match self.role:
- case "user":
- role = "USER"
- case _:
- role = "CHATBOT"
-
- return {"role": role, "message": self.message}
- else:
- return {"role": self.role, "content": self.message}
-
-
-class Memory:
- def __init__(self, reranker) -> None:
- self.history = []
- self.reranker = reranker
- self.user = "user"
- self.bot = "assistant"
-
- def add(self, prompt: str, response: str):
- self.history.append(
- Log(
- user=Message(role=self.user, message=prompt),
- bot=Message(role=self.bot, message=response),
- )
- )
-
- def get(self) -> List[Log]:
- return [m.as_dict() for log in self.history for m in log.get()]
-
- def reset(self):
- self.history = []
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index b96b70a..f32ee77 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,6 +1,8 @@
from abc import abstractmethod
+from typing import List
-from rag.generator.prompt import Prompt
+from rag.rag import Message
+from rag.retriever.vector import Document
class AbstractReranker(type):
@@ -13,5 +15,9 @@ class AbstractReranker(type):
return cls._instances[cls]
@abstractmethod
- def rank(self, prompt: Prompt) -> Prompt:
- return prompt
+ def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+ pass
+
+ @abstractmethod
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ pass
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 43690a1..33c373d 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -1,10 +1,12 @@
import os
+from typing import List
import cohere
from loguru import logger as log
-from rag.generator.prompt import Prompt
+from rag.rag import Message
from rag.retriever.rerank.abstract import AbstractReranker
+from rag.retriever.vector import Document
class CohereReranker(metaclass=AbstractReranker):
@@ -13,22 +15,39 @@ class CohereReranker(metaclass=AbstractReranker):
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- response = self.client.rerank(
- model="rerank-english-v3.0",
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- top_n=self.top_k,
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
+ query=query,
+ documents=[d.text for d in documents],
+ top_n=self.top_k,
+ )
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
)
- ranking = list(
- filter(
- lambda x: x.relevance_score > self.relevance_threshold,
- response.results,
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+ )
+ return [documents[r.index] for r in ranking]
+
+
+    def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+        response = self.client.rerank(
+            model="rerank-english-v3.0",
+            query=query,
+            documents=[m.content for m in messages],
+            top_n=self.top_k,
+        )
+ ranking = list(
+ filter(
+ lambda x: x.relevance_score > self.relevance_threshold,
+ response.results,
)
- prompt.documents = [prompt.documents[r.index] for r in ranking]
- return prompt
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ )
+ return [messages[r.index] for r in ranking]
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 8e94882..e727165 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,10 +2,10 @@ import os
from typing import List
from loguru import logger as log
+from rag.rag import Message
+from rag.retriever.vector import Document
from sentence_transformers import CrossEncoder
-from rag.generator.prompt import Prompt
-from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -15,42 +15,32 @@ class Reranker(metaclass=AbstractReranker):
self.top_k = int(os.environ["RERANK_TOP_K"])
self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
- def rank(self, prompt: Prompt) -> Prompt:
- if prompt.documents:
- results = self.model.rank(
- query=prompt.query,
- documents=[d.text for d in prompt.documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(
- lambda x: x.get("score", 0.0) > self.relevance_threshold, results
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}"
- )
- prompt.documents = [
- prompt.documents[r.get("corpus_id", 0)] for r in ranking
- ]
- return prompt
+    def rerank_documents(self, query: str, documents: List[Document]) -> List[Document]:
+ results = self.model.rank(
+ query=query,
+ documents=[d.text for d in documents],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
+ )
+ return [documents[r.get("corpus_id", 0)] for r in ranking]
- def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
- if history:
- results = self.model.rank(
- query=prompt.query,
- documents=[m.bot.message for m in history],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(
- lambda x: x.get("score", 0.0) > self.relevance_threshold, results
- )
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
- )
- history = [history[r.get("corpus_id", 0)] for r in ranking]
- return history
+ def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ results = self.model.rank(
+ query=query,
+            documents=[m.content for m in messages],
+ return_documents=False,
+ top_k=self.top_k,
+ )
+ ranking = list(
+ filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ )
+ log.debug(
+ f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ )
+ return [messages[r.get("corpus_id", 0)] for r in ranking]
diff --git a/rag/ui.py b/rag/ui.py
index ddb3d78..a453f47 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -20,15 +20,9 @@ def set_chat_users():
ss.bot = "assistant"
-@st.cache_resource
-def load_retriever():
- log.debug("Loading retriever model")
- st.session_state.retriever = Retriever()
-
-
def load_generator(model: str):
- log.debug("Loading generator model")
- st.session_state.generator = get_generator(model)
+    log.debug("Loading rag")
+    st.session_state.rag = Rag(model)
def load_reranker(model: str):
@@ -70,17 +64,10 @@ def generate_chat(query: str):
with st.chat_message(ss.user):
st.write(query)
- retriever = ss.retriever
- generator = ss.generator
- reranker = ss.reranker
-
- documents = retriever.retrieve(query)
- prompt = Prompt(query, documents)
-
- prompt = reranker.rank(prompt)
-
+    rag = ss.rag
+    documents = rag.retrieve(query)
+    prompt = Prompt(query, documents, rag.client)
    with st.chat_message(ss.bot):
-        response = st.write_stream(generator.generate(prompt))
+        response = st.write_stream(rag.generate(prompt))
display_context(prompt.documents)
store_chat(query, response, prompt.documents)