Diffstat (limited to 'rag')
-rw-r--r--  rag/cli.py                     |  11
-rw-r--r--  rag/generator/prompt.py        |   9
-rw-r--r--  rag/mcp/__init__.py            |   0
-rw-r--r--  rag/mcp/client.py              |   0
-rw-r--r--  rag/mcp/server.py              |   0
-rw-r--r--  rag/message.py                 |  25
-rw-r--r--  rag/model.py                   |  26
-rw-r--r--  rag/retriever/document.py      |   2
-rw-r--r--  rag/retriever/encoder.py       |  42
-rw-r--r--  rag/retriever/parser/pdf.py    |   6
-rw-r--r--  rag/retriever/rerank/local.py  |  45
-rw-r--r--  rag/retriever/retriever.py     |  57
-rw-r--r--  rag/retriever/vector.py        |  32
-rw-r--r--  rag/static/styles.tcss         |  48
-rw-r--r--  rag/tui.py                     |  71
-rw-r--r--  rag/ui.py                      | 134
-rw-r--r--  rag/workflows.py               |   7
17 files changed, 285 insertions, 230 deletions
diff --git a/rag/cli.py b/rag/cli.py
index 9d54549..0366be9 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
from rag.generator.prompt import Prompt
from rag.model import Rag
+from rag.retriever.retriever import FilePath
from rag.retriever.retriever import Retriever
@@ -38,13 +39,13 @@ def cli():
default=None,
)
@click.option("-v", "--verbose", count=True)
-def upload(directory: str, verbose: int):
+def index(directory: str, verbose: int):
configure_logging(verbose)
log.info(f"Uploading pfs found in directory {directory}...")
retriever = Retriever()
pdfs = Path(directory).glob("**/*.pdf")
for path in tqdm(list(pdfs)):
- retriever.add_pdf(path=path)
+ retriever.index(FilePath(path))
@click.command()
@@ -56,7 +57,7 @@ def upload(directory: str, verbose: int):
help="Generator and rerank model",
)
@click.option("-v", "--verbose", count=True)
-def rag(client: str, verbose: int):
+def search(client: str, verbose: int):
configure_logging(verbose)
rag = Rag(client)
while True:
@@ -92,8 +93,8 @@ def drop():
vec_db.delete_collection()
-cli.add_command(rag)
-cli.add_command(upload)
+cli.add_command(search)
+cli.add_command(index)
cli.add_command(drop)
if __name__ == "__main__":
diff --git a/rag/generator/prompt.py b/rag/generator/prompt.py
index 75966e8..16ea447 100644
--- a/rag/generator/prompt.py
+++ b/rag/generator/prompt.py
@@ -4,10 +4,8 @@ from typing import List
from rag.retriever.vector import Document
ANSWER_INSTRUCTION = (
- "Do not attempt to answer the query without relevant context, and do not use "
- "prior knowledge or training data!\n"
- "If the context does not contain the answer or is empty, only reply that you "
- "cannot answer the query given the context."
+ "Using the information contained in the context, give a comprehensive answer to the question.\n"
+ "If the answer cannot be deduced from the context, do not give an answer.\n\n"
)
@@ -30,8 +28,7 @@ class Prompt:
else:
return (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n"
- "Using the information contained in the context, give a comprehensive answer to the question.\n"
- "If the answer cannot be deduced from the context, do not give an answer.\n\n"
+ f"{ANSWER_INSTRUCTION}"
"Context:\n"
"---\n"
f"{self.__context(self.documents)}\n\n"
diff --git a/rag/mcp/__init__.py b/rag/mcp/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/mcp/__init__.py
diff --git a/rag/mcp/client.py b/rag/mcp/client.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/mcp/client.py
diff --git a/rag/mcp/server.py b/rag/mcp/server.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rag/mcp/server.py
diff --git a/rag/message.py b/rag/message.py
index d628982..d207596 100644
--- a/rag/message.py
+++ b/rag/message.py
@@ -1,5 +1,7 @@
from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, List
+
+from loguru import logger as log
@dataclass
@@ -13,3 +15,24 @@ class Message:
return {"role": self.role, "message": self.content}
else:
return {"role": self.role, "content": self.content}
+
+
+@dataclass
+class Messages:
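+    """A chat history whose messages can be filtered and reordered in place."""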
+ messages: List[Message]
+
+    def __add__(self, message: Message) -> "Messages":
+        self.messages.append(message)
+        return self
+
+ def __len__(self):
+ return len(self.messages)
+
+ def reset(self):
+ log.debug("Resetting messages...")
+ self.messages = []
+
+ def content(self) -> List[str]:
+ return [m.content for m in self.messages]
+
+ def rerank(self, rankings: List[int]):
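+        """Keep only the messages at the given corpus indices, in ranked order."""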
+ self.messages = [self.messages[r] for r in rankings]
diff --git a/rag/model.py b/rag/model.py
index b186d43..c6df9e2 100644
--- a/rag/model.py
+++ b/rag/model.py
@@ -1,10 +1,11 @@
-from typing import Any, Generator, List
+from typing import Any, Generator, List, Optional
from loguru import logger as log
from rag.generator import get_generator
from rag.generator.prompt import Prompt
-from rag.message import Message
+from rag.message import Message, Messages
+from rag.retriever.encoder import Query
from rag.retriever.rerank import get_reranker
from rag.retriever.retriever import Retriever
-from rag.retriever.vector import Document
+from rag.retriever.vector import Document, Documents
@@ -21,6 +22,7 @@ class Rag:
self.generator = get_generator(self.client)
self.__set_roles()
+ # TODO: move this to messages
def __set_roles(self):
self.bot = "assistant" if self.client == "local" else "CHATBOT"
self.user = "user" if self.client == "local" else "USER"
@@ -34,21 +36,21 @@ class Rag:
-        self.__reset_messages()
log.debug(f"Swapped client to {self.client}")
- def __reset_messages(self):
- log.debug("Deleting messages...")
- self.messages = []
-
- def retrieve(self, query: str) -> List[Document]:
- documents = self.retriever.retrieve(query)
+    def retrieve(self, query: Query) -> Documents:
+ documents = self.retriever.search(query)
log.info(f"Found {len(documents)} relevant documents")
- return self.reranker.rerank_documents(query, documents)
+ return self.reranker.rerank(query, documents)
+ # TODO: move this to messages
def add_message(self, role: str, content: str):
self.messages.append(Message(role=role, content=content, client=self.client))
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
- if self.messages:
- messages = self.reranker.rerank_messages(prompt.query, self.messages)
+ # TODO: fix the reranking
+ def generate(
+ self, prompt: Prompt, messages: Optional[Messages] = None
+ ) -> Generator[Any, Any, Any]:
+ if messages:
+            messages = self.reranker.rerank(Query(prompt.query), messages)
else:
messages = []
messages.append(
diff --git a/rag/retriever/document.py b/rag/retriever/document.py
index 132ec4b..df7a057 100644
--- a/rag/retriever/document.py
+++ b/rag/retriever/document.py
@@ -44,7 +44,7 @@ class DocumentDB:
)
self.conn.commit()
- def add(self, blob: Blob) -> bool:
+ def create(self, blob: Blob) -> bool:
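+        """Record the document's hash; True if it is new, False if it already exists."""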
with self.conn.cursor() as cur:
hash = self.__hash(blob)
cur.execute(
diff --git a/rag/retriever/encoder.py b/rag/retriever/encoder.py
index b68c3bb..8b02a14 100644
--- a/rag/retriever/encoder.py
+++ b/rag/retriever/encoder.py
@@ -1,7 +1,8 @@
+from dataclasses import dataclass
import hashlib
import os
from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
import ollama
from langchain_core.documents import Document
@@ -9,29 +10,42 @@ from loguru import logger as log
from qdrant_client.http.models import StrictFloat
from tqdm import tqdm
-from .vector import Point
+from .vector import Documents, Point
+
+@dataclass
+class Query:
+    query: str
+
+
+# Parsed document chunks, as produced by PDFParser.chunk.
+@dataclass
+class Chunks:
+    chunks: List[Document]
+
+
+Input = Query | Chunks
class Encoder:
def __init__(self) -> None:
self.model = os.environ["ENCODER_MODEL"]
- self.query_prompt = "Represent this sentence for searching relevant passages: "
-
- def __encode(self, prompt: str) -> List[StrictFloat]:
- return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
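+        # mxbai-embed-large expects queries to carry this retrieval prompt prefix.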
+ self.preamble = (
+ "Represent this sentence for searching relevant passages: "
+            if "mxbai-embed-large" in self.model
+ else ""
+ )
def __get_source(self, metadata: Dict[str, str]) -> str:
source = metadata["source"]
return Path(source).name
- def encode_document(self, chunks: List[Document]) -> List[Point]:
+ def __encode(self, prompt: str) -> List[StrictFloat]:
+ return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
+
+ # TODO: move this to vec db and just return the embeddings
+ # TODO: use late chunking here
+ def __encode_document(self, chunks: List[Document]) -> List[Point]:
log.debug("Encoding document...")
return [
Point(
id=hashlib.sha256(
chunk.page_content.encode(encoding="utf-8")
).hexdigest(),
- vector=self.__encode(chunk.page_content),
+ vector=list(self.__encode(chunk.page_content)),
payload={
"text": chunk.page_content,
"source": self.__get_source(chunk.metadata),
@@ -40,8 +54,14 @@ class Encoder:
for chunk in tqdm(chunks)
]
- def encode_query(self, query: str) -> List[StrictFloat]:
+ def __encode_query(self, query: str) -> List[StrictFloat]:
log.debug(f"Encoding query: {query}")
- if self.model == "mxbai-embed-large":
- query = self.query_prompt + query
+ query = self.preamble + query
return self.__encode(query)
+
+ def encode(self, x: Input) -> Union[List[StrictFloat], List[Point]]:
+ match x:
+ case Query(query):
+ return self.__encode_query(query)
+            case Chunks(chunks):
+                return self.__encode_document(chunks)
diff --git a/rag/retriever/parser/pdf.py b/rag/retriever/parser/pdf.py
index 4c5addc..3253dc1 100644
--- a/rag/retriever/parser/pdf.py
+++ b/rag/retriever/parser/pdf.py
@@ -8,8 +8,10 @@ from langchain_community.document_loaders.parsers.pdf import (
PyPDFParser,
)
from langchain_core.documents import Document
+from rag.retriever.encoder import Chunks
+# TODO: fix the PDFParser, remove langchain
class PDFParser:
def __init__(self) -> None:
self.parser = PyPDFParser(password=None, extract_images=False)
@@ -22,7 +24,7 @@ class PDFParser:
def chunk(
self, document: List[Document], source: Optional[str] = None
- ) -> List[Document]:
+ ) -> Chunks:
splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ["CHUNK_SIZE"]),
chunk_overlap=int(os.environ["CHUNK_OVERLAP"]),
@@ -31,4 +33,4 @@ class PDFParser:
if source is not None:
for c in chunks:
c.metadata["source"] = source
- return chunks
+ return Chunks(chunks)
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 231d50a..e2bef31 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,46 +1,39 @@
import os
-from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
-from rag.message import Message
+from rag.message import Messages
+from rag.retriever.encoder import Query
from rag.retriever.rerank.abstract import AbstractReranker
-from rag.retriever.vector import Document
+from rag.retriever.vector import Documents
+
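+# Rerankable collections: retrieved documents or prior chat messages.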
+Context = Documents | Messages
class Reranker(metaclass=AbstractReranker):
def __init__(self) -> None:
self.model = CrossEncoder(os.environ["RERANK_MODEL"], device="cpu")
self.top_k = int(os.environ["RERANK_TOP_K"])
- self.relevance_threshold = float(os.environ["RETRIEVER_RELEVANCE_THRESHOLD"])
-
- def rerank_documents(self, query: str, documents: List[Document]) -> List[str]:
- results = self.model.rank(
- query=query,
- documents=[d.text for d in documents],
- return_documents=False,
- top_k=self.top_k,
- )
- ranking = list(
- filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
- )
- log.debug(
- f"Reranking gave {len(ranking)} relevant documents of {len(documents)}"
- )
- return [documents[r.get("corpus_id", 0)] for r in ranking]
+ self.relevance_threshold = float(os.environ["RERANK_RELEVANCE_THRESHOLD"])
- def rerank_messages(self, query: str, messages: List[Message]) -> List[Message]:
+ def rerank(self, query: Query, documents: Context) -> Context:
results = self.model.rank(
- query=query,
- documents=[m.content for m in messages],
+ query=query.query,
+ documents=documents.content(),
return_documents=False,
top_k=self.top_k,
)
- ranking = list(
- filter(lambda x: x.get("score", 0.0) > self.relevance_threshold, results)
+ rankings = list(
+ map(
+ lambda x: x.get("corpus_id", 0),
+ filter(
+ lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+ ),
+ )
)
log.debug(
- f"Reranking gave {len(ranking)} relevant chat messages of {len(messages)}"
+ f"Reranking gave {len(rankings)} relevant documents of {len(documents)}"
)
- return [messages[r.get("corpus_id", 0)] for r in ranking]
+ documents.rerank(rankings)
+ return documents
diff --git a/rag/retriever/retriever.py b/rag/retriever/retriever.py
index 351cfb0..7d43941 100644
--- a/rag/retriever/retriever.py
+++ b/rag/retriever/retriever.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional
@@ -5,11 +6,25 @@ from typing import List, Optional
from loguru import logger as log
from .document import DocumentDB
-from .encoder import Encoder
+from .encoder import Encoder, Query
from .parser.pdf import PDFParser
-from .vector import Document, VectorDB
+from .vector import Document, Documents, VectorDB
+@dataclass
+class FilePath:
+ path: Path
+
+
+@dataclass
+class Blob:
+ blob: BytesIO
+ source: Optional[str] = None
+
+
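+# A document to index: a PDF path on disk or an in-memory blob (e.g. an upload).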
+FileType = FilePath | Blob
+
+
class Retriever:
def __init__(self) -> None:
self.pdf_parser = PDFParser()
@@ -17,35 +32,29 @@ class Retriever:
self.doc_db = DocumentDB()
self.vec_db = VectorDB()
- def __add_pdf_from_path(self, path: Path):
- log.debug(f"Adding pdf from {path}")
+ def __index_pdf_from_path(self, path: Path):
+ log.debug(f"Indexing pdf from {path}")
blob = self.pdf_parser.from_path(path)
- self.__add_pdf_from_blob(blob)
+ self.__index_pdf_from_blob(blob, None)
- def __add_pdf_from_blob(self, blob: BytesIO, source: Optional[str] = None):
- if self.doc_db.add(blob):
- log.debug("Adding pdf to vector database...")
+ def __index_pdf_from_blob(self, blob: BytesIO, source: Optional[str]):
+ if self.doc_db.create(blob):
+ log.debug("Indexing pdf to vector database...")
document = self.pdf_parser.from_data(blob)
chunks = self.pdf_parser.chunk(document, source)
- points = self.encoder.encode_document(chunks)
- self.vec_db.add(points)
+ points = self.encoder.encode(chunks)
+ self.vec_db.index(points)
else:
log.debug("Document already exists!")
- def add_pdf(
- self,
- path: Optional[Path] = None,
- blob: Optional[BytesIO] = None,
- source: Optional[str] = None,
- ):
- if path:
- self.__add_pdf_from_path(path)
- elif blob and source:
- self.__add_pdf_from_blob(blob, source)
- else:
- log.error("Invalid input!")
+ def index(self, filetype: FileType):
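+        """Parse, embed, and store a PDF given as a path or an in-memory blob."""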
+ match filetype:
+ case FilePath(path):
+ self.__index_pdf_from_path(path)
+ case Blob(blob, source):
+ self.__index_pdf_from_blob(blob, source)
- def retrieve(self, query: str) -> List[Document]:
- log.debug(f"Finding documents matching query: {query}")
- query_emb = self.encoder.encode_query(query)
+    def search(self, query: Query) -> Documents:
+ log.debug(f"Finding documents matching query: {query.query}")
+ query_emb = self.encoder.encode(query)
return self.vec_db.search(query_emb)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index 1a484f3..b36aee8 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -21,6 +21,20 @@ class Document:
text: str
+@dataclass
+class Documents:
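+    """Retrieved documents that can be filtered and reordered in place."""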
+ documents: List[Document]
+
+ def __len__(self):
+ return len(self.documents)
+
+ def content(self) -> List[str]:
+ return [d.text for d in self.documents]
+
+ def rerank(self, rankings: List[int]):
+ self.documents = [self.documents[r] for r in rankings]
+
+
class VectorDB:
def __init__(self):
self.dim = int(os.environ["EMBEDDING_DIM"])
@@ -47,7 +61,7 @@ class VectorDB:
log.info(f"Deleting collection {self.collection_name}")
self.client.delete_collection(self.collection_name)
- def add(self, points: List[Point]):
+ def index(self, points: List[Point]):
log.debug(f"Inserting {len(points)} vectors into the vector db...")
self.client.upload_points(
collection_name=self.collection_name,
@@ -59,7 +73,7 @@ class VectorDB:
max_retries=3,
)
- def search(self, query: List[float]) -> List[Document]:
+ def search(self, query: List[float]) -> Documents:
log.debug("Searching for vectors...")
hits = self.client.search(
collection_name=self.collection_name,
@@ -68,11 +82,13 @@ class VectorDB:
score_threshold=self.score_threshold,
)
log.debug(f"Got {len(hits)} hits in the vector db with limit={self.top_k}")
- return list(
- map(
- lambda h: Document(
- title=h.payload.get("source", ""), text=h.payload["text"]
- ),
- hits,
+ return Documents(
+ list(
+ map(
+ lambda h: Document(
+ title=h.payload.get("source", ""), text=h.payload["text"]
+ ),
+ hits,
+ )
)
)
diff --git a/rag/static/styles.tcss b/rag/static/styles.tcss
new file mode 100644
index 0000000..902f60b
--- /dev/null
+++ b/rag/static/styles.tcss
@@ -0,0 +1,48 @@
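+/* Styles for the Textual TUI in rag/tui.py (its CSS_PATH points here). */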
+#rag {
+ background: #151515;
+ color: #E1E1E1;
+ height: auto;
+ width: 100%;
+ margin: 0 0 2 0;
+ grid-size: 2 2;
+}
+
+#history {
+ scrollbar-size: 1 1;
+ height: 80%;
+}
+
+#output {
+ height: 100%;
+}
+
+#chat {
+ border: round #373737;
+ padding: 1;
+ margin-top: 2;
+ margin-right: 2;
+}
+
+#input {
+ height: 20%;
+ dock: bottom;
+}
+
+#references {
+ scrollbar-size: 1 1;
+ dock: right;
+ border-left: solid #373737;
+ height: 100%;
+ width: 20%;
+}
+
+Input {
+ dock: bottom;
+ align_horizontal: center;
+    align-horizontal: center;
+ height: 10%;
+ background: #202020;
+ color: #E1E1E1;
+ border: round #373737;
+}
+
diff --git a/rag/tui.py b/rag/tui.py
new file mode 100644
index 0000000..ead4f71
--- /dev/null
+++ b/rag/tui.py
@@ -0,0 +1,71 @@
+from textual import events
+from textual.app import App, ComposeResult
+from textual.containers import Vertical, Container, VerticalScroll
+from textual.widgets import Input, Label, Markdown, Static, TabbedContent, TabPane
+
+JESSICA = """
+# Lady Jessica
+
+Bene Gesserit and concubine of Leto, and mother of Paul and Alia.
+"""
+
+PAUL = """
+# Paul Atreides
+
+Son of Leto and Jessica.
+"""
+
+TEXT = """\
+Docking a widget removes it from the layout and fixes its position, aligned to either the top, right, bottom, or left edges of a container.
+
+Docked widgets will not scroll out of view, making them ideal for sticky headers, footers, and sidebars.
+
+"""
+
+class TabbedApp(App):
+ """An example of tabbed content."""
+
+ CSS_PATH = "static/styles.tcss"
+
+ BINDINGS = [
+ ("n", "show_tab('rag')", "Rag"),
+ ("e", "show_tab('settings')", "Settings"),
+ ]
+
+ def compose(self) -> ComposeResult:
+ """Compose app with tabbed content."""
+ # Add the TabbedContent widget
+ with Container(id="rag"):
+ with VerticalScroll(id="references"):
+ yield Static("test3", classes="context")
+ with VerticalScroll(id="history"):
+ yield Static(TEXT * 10, classes="output")
+ with Vertical(id="chat"):
+ yield Static("test2", classes="input")
+ # with TabbedContent(initial="rag"):
+ # with TabPane("RAG", id="rag"): # First tab
+ # # yield Input(placeholder=">>", id="chat")
+ # yield Static("test1", classes="chat")
+ # yield Static("test", classes="context")
+ # with TabPane("Settings", id="settings"):
+ # yield Markdown(JESSICA)
+ # with TabbedContent("Paul", "Alia"):
+ # yield TabPane("Paul", Label("First child"))
+ # yield TabPane("Alia", Label("Second child"))
+
+ def action_show_tab(self, tab: str) -> None:
+ """Switch to a new tab."""
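+        # NOTE: compose() currently has TabbedContent commented out, so this
+        # action only works once the tabbed layout above is re-enabled.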
+ self.get_child_by_type(TabbedContent).active = tab
+
+ def on_key(self, event: events.Key) -> None:
+ if event.key == "s":
+ self.action_show_tab("settings")
+ if event.key == "r":
+ self.action_show_tab("rag")
+ if event.key == "q":
+ self.exit()
+
+
+if __name__ == "__main__":
+ app = TabbedApp()
+ app.run()
diff --git a/rag/ui.py b/rag/ui.py
deleted file mode 100644
index 2192ad8..0000000
--- a/rag/ui.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from typing import List
-
-import streamlit as st
-from dotenv import load_dotenv
-from langchain_community.document_loaders.blob_loaders import Blob
-from loguru import logger as log
-
-from rag.generator import MODELS
-from rag.generator.prompt import Prompt
-from rag.message import Message
-from rag.model import Rag
-from rag.retriever.vector import Document
-
-
-def set_chat_users():
- log.debug("Setting user and bot value")
- ss = st.session_state
- ss.user = "user"
- ss.bot = "assistant"
-
-@st.cache_resource
-def load_rag():
- log.debug("Loading Rag...")
- st.session_state.rag = Rag()
-
-
-@st.cache_resource
-def set_client(client: str):
- log.debug("Setting client...")
- rag = st.session_state.rag
- rag.set_client(client)
-
-
-@st.cache_data(show_spinner=False)
-def upload(files):
- rag = st.session_state.rag
- with st.spinner("Uploading documents..."):
- for file in files:
- source = file.name
- blob = Blob.from_data(file.read())
- rag.retriever.add_pdf(blob=blob, source=source)
-
-
-def display_context(documents: List[Document]):
- with st.popover("See Context"):
- for i, doc in enumerate(documents):
- st.markdown(f"### Document {i}")
- st.markdown(f"**Title: {doc.title}**")
- st.markdown(doc.text)
- st.markdown("---")
-
-
-def display_chat():
- ss = st.session_state
- for msg in ss.chat:
- if isinstance(msg, list):
- display_context(msg)
- else:
- st.chat_message(msg.role).write(msg.content)
-
-
-def generate_chat(query: str):
- ss = st.session_state
-
- with st.chat_message(ss.user):
- st.write(query)
-
- rag = ss.rag
- documents = rag.retrieve(query)
- prompt = Prompt(query, documents, ss.model)
- with st.chat_message(ss.bot):
- response = st.write_stream(rag.generate(prompt))
-
- rag.add_message(rag.bot, response)
-
- display_context(prompt.documents)
- store_chat(prompt, response)
-
-
-def store_chat(prompt: Prompt, response: str):
- log.debug("Storing chat")
- ss = st.session_state
- query = Message(ss.user, prompt.query, ss.model)
- response = Message(ss.bot, response, ss.model)
- ss.chat.append(query)
- ss.chat.append(response)
- ss.chat.append(prompt.documents)
-
-
-def sidebar():
- with st.sidebar:
- st.header("Grounding")
- st.markdown(
- (
- "These files will be uploaded to the knowledge base and used "
- "as groudning if they are relevant to the question."
- )
- )
-
- files = st.file_uploader(
- "Choose pdfs to add to the knowledge base",
- type="pdf",
- accept_multiple_files=True,
- )
-
- upload(files)
-
- st.header("Model")
- st.markdown(
- "Select the model that will be used for reranking and generating the answer."
- )
- st.selectbox("Model", key="model", options=MODELS)
- set_client(st.session_state.model)
-
-
-def page():
- ss = st.session_state
- if "chat" not in st.session_state:
- ss.chat = []
-
- display_chat()
-
- query = st.chat_input("Enter query here")
- if query:
- generate_chat(query)
-
-
-if __name__ == "__main__":
- load_dotenv()
- st.title("Retrieval Augmented Generation")
- set_chat_users()
- load_rag()
- sidebar()
- page()
diff --git a/rag/workflows.py b/rag/workflows.py
new file mode 100644
index 0000000..58d1d78
--- /dev/null
+++ b/rag/workflows.py
@@ -0,0 +1,7 @@
+
+
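+# Stubs for backend-specific generation workflows (local ollama vs. cohere).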
+def ollama_workflow():
+ pass
+
+def cohere_workflow():
+ pass