summaryrefslogtreecommitdiff
path: root/rag/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/ui.py')
-rw-r--r--rag/ui.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/rag/ui.py b/rag/ui.py
index 84dbbeb..83a22e2 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -1,16 +1,15 @@
import streamlit as st
from langchain_community.document_loaders.blob_loaders import Blob
+from .rag import RAG
-try:
- from rag.rag import RAG
- from rag.llm.ollama_generator import Prompt
-except ModuleNotFoundError:
- from rag import RAG
- from llm.ollama_generator import Prompt
+from .generator import get_generator
+from .generator.prompt import Prompt
rag = RAG()
+MODELS = ["ollama", "cohere"]
+
def upload_pdfs():
files = st.file_uploader(
@@ -26,13 +25,15 @@ def upload_pdfs():
for file in files:
source = file.name
blob = Blob.from_data(file.read())
- rag.add_pdf_from_blob(blob, source)
+ rag.add_pdf(blob, source)
if __name__ == "__main__":
ss = st.session_state
st.header("RAG-UI")
+ model = st.selectbox("Model", options=MODELS)
+
upload_pdfs()
with st.form(key="query"):
@@ -56,7 +57,7 @@ if __name__ == "__main__":
query = ss.get("query", "")
with st.spinner("Searching for documents..."):
- documents = rag.search(query)
+ documents = rag.retrieve(query)
prompt = Prompt(query, documents)
@@ -69,6 +70,6 @@ if __name__ == "__main__":
st.markdown("---")
with result_column:
+ generator = get_generator(model)
st.markdown("### Answer")
- st.write_stream(rag.retrieve(prompt))
-
+ st.write_stream(rag.generate(generator, prompt))