summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rag/cli.py17
-rw-r--r--rag/generator/__init__.py6
-rw-r--r--rag/generator/abstract.py11
-rw-r--r--rag/generator/cohere.py12
-rw-r--r--rag/generator/ollama.py11
-rw-r--r--rag/generator/prompt.py10
-rw-r--r--rag/message.py15
-rw-r--r--rag/model.py (renamed from rag/rag.py)48
-rw-r--r--rag/retriever/rerank/abstract.py2
-rw-r--r--rag/retriever/rerank/cohere.py10
-rw-r--r--rag/retriever/rerank/local.py2
-rw-r--r--rag/ui.py52
12 files changed, 96 insertions, 100 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 6c4d3e0..070427d 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
from rag.generator import get_generator
from rag.generator.prompt import Prompt
+from rag.model import Rag
from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
@@ -57,22 +58,20 @@ def upload(directory: str, verbose: int):
prompt="Enter your query",
)
@click.option(
- "-m",
- "--model",
+ "-c",
+ "--client",
type=click.Choice(["local", "cohere"], case_sensitive=False),
default="local",
help="Generator and rerank model",
)
@click.option("-v", "--verbose", count=True)
-def rag(query: str, model: str, verbose: int):
+def rag(query: str, client: str, verbose: int):
configure_logging(verbose)
- retriever = Retriever()
- generator = get_generator(model)
- reranker = get_reranker(model)
- documents = retriever.retrieve(query)
- prompt = reranker.rank(Prompt(query, documents))
+ rag = Rag(client)
+ documents = rag.retrieve(query)
+ prompt = Prompt(query, documents, client)
print("Answer: ")
- for chunk in generator.generate(prompt):
+ for chunk in rag.generate(prompt):
print(chunk, end="", flush=True)
print("\n\n")
diff --git a/rag/generator/__init__.py b/rag/generator/__init__.py
index a776231..770db16 100644
--- a/rag/generator/__init__.py
+++ b/rag/generator/__init__.py
@@ -1,8 +1,8 @@
from typing import Type
-from .abstract import AbstractGenerator
-from .cohere import Cohere
-from .ollama import Ollama
+from rag.generator.abstract import AbstractGenerator
+from rag.generator.cohere import Cohere
+from rag.generator.ollama import Ollama
MODELS = ["local", "cohere"]
diff --git a/rag/generator/abstract.py b/rag/generator/abstract.py
index 995e937..3ce997e 100644
--- a/rag/generator/abstract.py
+++ b/rag/generator/abstract.py
@@ -1,9 +1,8 @@
from abc import abstractmethod
-from typing import Any, Generator
+from typing import Any, Generator, List
-from rag.rag import Message
-
-from .prompt import Prompt
+from rag.message import Message
+from rag.retriever.vector import Document
class AbstractGenerator(type):
@@ -16,7 +15,5 @@ class AbstractGenerator(type):
return cls._instances[cls]
@abstractmethod
- def generate(
- self, prompt: Prompt, messages: List[Message]
- ) -> Generator[Any, Any, Any]:
+ def generate(self, messages: List[Message], documents: List[Document]) -> Generator[Any, Any, Any]:
pass
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index f30fe69..575452f 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -5,10 +5,10 @@ from typing import Any, Generator, List
import cohere
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
+from rag.retriever.vector import Document
from .abstract import AbstractGenerator
-from .prompt import Prompt
class Cohere(metaclass=AbstractGenerator):
@@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator):
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
def generate(
- self, prompt: Prompt, messages: List[Message]
+ self, messages: List[Message], documents: List[Document]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
for event in self.client.chat_stream(
- message=prompt.to_str(),
- documents=[asdict(d) for d in prompt.documents],
- chat_history=[m.as_dict() for m in messages],
+ message=messages[-1].content,
+ documents=[asdict(d) for d in documents],
+ chat_history=[m.as_dict() for m in messages[:-1]],
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index ff5402b..84563bb 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -4,10 +4,10 @@ from typing import Any, Generator, List
import ollama
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
+from rag.retriever.vector import Document
from .abstract import AbstractGenerator
-from .prompt import Prompt
class Ollama(metaclass=AbstractGenerator):
@@ -16,12 +16,9 @@ class Ollama(metaclass=AbstractGenerator):
log.debug(f"Using {self.model} for generator...")
def generate(
- self, prompt: Prompt, messages: List[Message]
+ self, messages: List[Message], documents: List[Document]
) -> Generator[Any, Any, Any]:
log.debug("Generating answer with ollama...")
- messages = messages.append(
- Message(role="user", content=prompt.to_str(), client="ollama")
- )
messages = [m.as_dict() for m in messages]
for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
- yield chunk["response"]
+ yield chunk["message"]["content"]
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 4840fdc..cedf610 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -15,7 +15,7 @@ ANSWER_INSTRUCTION = (
class Prompt:
query: str
documents: List[Document]
- generator_model: str
+ client: str
def __context(self, documents: List[Document]) -> str:
results = [
@@ -25,17 +25,15 @@ class Prompt:
return "\n".join(results)
def to_str(self) -> str:
- if self.generator_model == "cohere":
+ if self.client == "cohere":
return f"{self.query}\n\n{ANSWER_INSTRUCTION}"
else:
return (
"Context information is below.\n"
- "---------------------\n"
+ "---\n"
f"{self.__context(self.documents)}\n\n"
- "---------------------\n"
+ "---\n"
f"{ANSWER_INSTRUCTION}"
- "Do not attempt to answer the query without relevant context and do not use"
- " prior knowledge or training data!\n"
f"Query: {self.query.strip()}\n\n"
"Answer:"
)
diff --git a/rag/message.py b/rag/message.py
new file mode 100644
index 0000000..d628982
--- /dev/null
+++ b/rag/message.py
@@ -0,0 +1,15 @@
+from dataclasses import dataclass
+from typing import Dict
+
+
+@dataclass
+class Message:
+ role: str
+ content: str
+ client: str
+
+ def as_dict(self) -> Dict[str, str]:
+ if self.client == "cohere":
+ return {"role": self.role, "message": self.content}
+ else:
+ return {"role": self.role, "content": self.content}
diff --git a/rag/rag.py b/rag/model.py
index 1f6a176..b186d43 100644
--- a/rag/rag.py
+++ b/rag/model.py
@@ -1,43 +1,32 @@
-from dataclasses import dataclass
-from typing import Any, Dict, Generator, List
+from typing import Any, Generator, List
from loguru import logger as log
from rag.generator import get_generator
from rag.generator.prompt import Prompt
+from rag.message import Message
from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
from rag.retriever.vector import Document
-@dataclass
-class Message:
- role: str
- content: str
- client: str
-
- def as_dict(self) -> Dict[str, str]:
- if self.client == "cohere":
- return {"role": self.role, "message": self.content}
- else:
- return {"role": self.role, "content": self.content}
-
-
class Rag:
- def __init__(self, client: str) -> None:
+ def __init__(self, client: str = "local") -> None:
+ self.bot = None
+ self.user = None
+ self.client = client
self.messages: List[Message] = []
self.retriever = Retriever()
- self.client = client
self.reranker = get_reranker(self.client)
self.generator = get_generator(self.client)
- self.bot = "assistant" if self.client == "ollama" else "CHATBOT"
- self.user = "user" if self.client == "ollama" else "USER"
+ self.__set_roles()
def __set_roles(self):
- self.bot = "assistant" if self.client == "ollama" else "CHATBOT"
- self.user = "user" if self.client == "ollama" else "USER"
+ self.bot = "assistant" if self.client == "local" else "CHATBOT"
+ self.user = "user" if self.client == "local" else "USER"
def set_client(self, client: str):
+ log.info(f"Setting client to {client}")
self.client = client
self.reranker = get_reranker(self.client)
self.generator = get_generator(self.client)
@@ -55,15 +44,14 @@ class Rag:
return self.reranker.rerank_documents(query, documents)
def add_message(self, role: str, content: str):
- self.messages.append(
- Message(role=role, content=content, client=self.client)
- )
+ self.messages.append(Message(role=role, content=content, client=self.client))
def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
- messages = self.reranker.rerank_messages(prompt.query, self.messages)
- self.messages.append(
- Message(
- role=self.user, content=prompt.to_str(), client=self.client
- )
+ if self.messages:
+ messages = self.reranker.rerank_messages(prompt.query, self.messages)
+ else:
+ messages = []
+ messages.append(
+ Message(role=self.user, content=prompt.to_str(), client=self.client)
)
- return self.generator.generate(prompt, messages)
+ return self.generator.generate(messages, prompt.documents)
diff --git a/rag/retriever/rerank/abstract.py b/rag/retriever/rerank/abstract.py
index f32ee77..015a60d 100644
--- a/rag/retriever/rerank/abstract.py
+++ b/rag/retriever/rerank/abstract.py
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import List
-from rag.memory import Message
+from rag.message import Message
from rag.retriever.vector import Document
diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py
index 33c373d..52f31a8 100644
--- a/rag/retriever/rerank/cohere.py
+++ b/rag/retriever/rerank/cohere.py
@@ -4,7 +4,7 @@ from typing import List
import cohere
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
from rag.retriever.rerank.abstract import AbstractReranker
from rag.retriever.vector import Document
@@ -35,11 +35,11 @@ class CohereReranker(metaclass=AbstractReranker):
def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
- response = self.model.rank(
+ response = self.client.rerank(
+ model="rerank-english-v3.0",
query=query,
- documents=[m.message for m in messages],
- return_documents=False,
- top_k=self.top_k,
+ documents=[m.content for m in messages],
+ top_n=self.top_k,
)
ranking = list(
filter(
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index e727165..fd42c2c 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -2,7 +2,7 @@ import os
from typing import List
from loguru import logger as log
-from rag.rag import Message
+from rag.message import Message
from rag.retriever.vector import Document
from sentence_transformers import CrossEncoder
diff --git a/rag/ui.py b/rag/ui.py
index 9a0f2cf..2192ad8 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,15 +1,14 @@
-from dataclasses import dataclass
-from typing import Dict, List
+from typing import List
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders.blob_loaders import Blob
from loguru import logger as log
-from rag.generator import MODELS, get_generator
+from rag.generator import MODELS
from rag.generator.prompt import Prompt
-from rag.retriever.rerank import get_reranker
-from rag.retriever.retriever import Retriever
+from rag.message import Message
+from rag.model import Rag
from rag.retriever.vector import Document
@@ -19,25 +18,27 @@ def set_chat_users():
ss.user = "user"
ss.bot = "assistant"
+@st.cache_resource
+def load_rag():
+ log.debug("Loading Rag...")
+ st.session_state.rag = Rag()
-def load_generator(model: str):
- log.debug("Loading rag")
- st.session_state.rag = get_generator(model)
-
-def load_reranker(model: str):
- log.debug("Loading reranker model")
- st.session_state.reranker = get_reranker(model)
+@st.cache_resource
+def set_client(client: str):
+ log.debug("Setting client...")
+ rag = st.session_state.rag
+ rag.set_client(client)
@st.cache_data(show_spinner=False)
def upload(files):
- retriever = st.session_state.retriever
+ rag = st.session_state.rag
with st.spinner("Uploading documents..."):
for file in files:
source = file.name
blob = Blob.from_data(file.read())
- retriever.add_pdf(blob=blob, source=source)
+ rag.retriever.add_pdf(blob=blob, source=source)
def display_context(documents: List[Document]):
@@ -55,7 +56,7 @@ def display_chat():
if isinstance(msg, list):
display_context(msg)
else:
- st.chat_message(msg.role).write(msg.message)
+ st.chat_message(msg.role).write(msg.content)
def generate_chat(query: str):
@@ -66,22 +67,24 @@ def generate_chat(query: str):
rag = ss.rag
documents = rag.retrieve(query)
- Prompt(query, documents, self.client)
+ prompt = Prompt(query, documents, ss.model)
with st.chat_message(ss.bot):
- response = st.write_stream(rag.generate(query))
+ response = st.write_stream(rag.generate(prompt))
+
+ rag.add_message(rag.bot, response)
display_context(prompt.documents)
- store_chat(query, response, prompt.documents)
+ store_chat(prompt, response)
-def store_chat(query: str, response: str, documents: List[Document]):
+def store_chat(prompt: Prompt, response: str):
log.debug("Storing chat")
ss = st.session_state
- query = Message(role=ss.user, message=query)
- response = Message(role=ss.bot, message=response)
+ query = Message(ss.user, prompt.query, ss.model)
+ response = Message(ss.bot, response, ss.model)
ss.chat.append(query)
ss.chat.append(response)
- ss.chat.append(documents)
+ ss.chat.append(prompt.documents)
def sidebar():
@@ -107,8 +110,7 @@ def sidebar():
"Select the model that will be used for reranking and generating the answer."
)
st.selectbox("Model", key="model", options=MODELS)
- load_generator(st.session_state.model)
- load_reranker(st.session_state.model)
+ set_client(st.session_state.model)
def page():
@@ -127,6 +129,6 @@ if __name__ == "__main__":
load_dotenv()
st.title("Retrieval Augmented Generation")
set_chat_users()
- load_retriever()
+ load_rag()
sidebar()
page()