author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-05-29 00:53:39 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2024-05-29 00:53:39 +0200
commit     716e3fe58adee5b8a6bfa91de4b3ba6cf204d172 (patch)
tree       778da9011d21051006fc206ce0978f0fc114b77b
parent     2d91c118d71a8dd7fbd7f9cf21f86e92da33827e (diff)
Wip memory
-rw-r--r--  README.md                        6
-rw-r--r--  docs/rag.html                   10
-rw-r--r--  notebooks/testing.ipynb        203
-rw-r--r--  rag/generator/cohere.py          7
-rw-r--r--  rag/generator/ollama.py          4
-rw-r--r--  rag/retriever/memory.py         51
-rw-r--r--  rag/retriever/rerank/local.py   21
-rw-r--r--  rag/ui.py                       14
8 files changed, 286 insertions, 30 deletions
diff --git a/README.md b/README.md
index 53d90b6..3010524 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,12 @@ streamlit run rag/ui.py
Yes, it is inefficient/dumb to use ollama when you can just load the models with python
in the same process.
+### TODO
+
+- [ ] Rerank chat history and only include messages relevant to the query.
+- [ ] Send the chat history as messages to ollama/cohere.
+- [ ] Create a database setup script.
+
### Inspiration
I took some inspiration from these tutorials:
diff --git a/docs/rag.html b/docs/rag.html
index 8294f1f..1005192 100644
--- a/docs/rag.html
+++ b/docs/rag.html
@@ -1798,9 +1798,11 @@ relevance:0},{begin:/->/}]})})();hljs.registerLanguage("ocaml",e)})();</script>
<img src="https://qdrant.tech/articles_data/what-is-rag-in-ai/how-generation-works.png" alt="image" width="1024" height="auto">
</center>
<div up-at-unpause pause></div>
+<div id="rag">
<center>
<img src="https://qdrant.tech/articles_data/what-is-rag-in-ai/rag-system.jpg" alt="image" width="1024" height="auto">
</center>
+</div>
<div up-at-unpause pause></div>
<center>
<img src="https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExZXM0ZGtsZzdldjh5cW54bnN1MTA1dDl3cjV2c2p2NmRiMHpkYmZyYyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/QYLQRR7IF48njkq5an/giphy.gif" alt="image" width="1024" height="auto">
@@ -1811,16 +1813,18 @@ relevance:0},{begin:/->/}]})})();hljs.registerLanguage("ocaml",e)})();</script>
<p focus-at-unpause=solution pause><span>Add another </span><strong><span>LLM</span></strong><span> of course!</span></p>
</div>
</center>
-<div up-at-unpause pause attributes the></div>
+<div up-at-unpause pause></div>
<h1 id="reranker"><a class="anchor" aria-hidden="true" href="#reranker"></a><span>Reranker</span></h1>
-<div left pause></div>
<center>
<img src="https://cdn.sanity.io/images/vr8gru94/production/906c3c0f8fe637840f134dbf966839ef89ac7242-3443x1641.png" alt="image" width="1024" height="auto">
</center>
<div up-at-unpause pause></div>
+<p><span>A common reranking model is the cross-encoder:</span></p>
<center>
-<img src="https://cdn.sanity.io/images/vr8gru94/production/4509817116ab72e27bae809c38cb48fbf1578b5d-2760x1420.png" alt="image" width="1024" height="auto">
+<img src="https://cdn.sanity.io/images/vr8gru94/production/9f0d2f75571bb58eecf2520a23d300a5fc5b1e2c-2440x1100.png" alt="image" width="1024" height="auto">
</center>
+<p><span>We plug this reranking model into the RAG pipeline...</span>
+<span up-at-unpause=rag pause></span></p>
</slip-body>
</slip-slip>
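
The cross-encoder shown in the slides scores each (query, document) pair jointly instead of embedding them separately. A minimal sketch with sentence_transformers, assuming an illustrative public checkpoint rather than the model this repo ships:

    from sentence_transformers import CrossEncoder

    # Score (query, document) pairs jointly and sort by relevance.
    # The checkpoint name is an assumption for illustration only.
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    query = "what is a convex function?"
    documents = [
        "A convex function lies below each of its chords.",
        "Docker anti-patterns in continuous delivery.",
    ]
    scores = model.predict([(query, d) for d in documents])
    print(sorted(zip(documents, scores), key=lambda p: p[1], reverse=True))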
diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb
index d255438..26ff2e4 100644
--- a/notebooks/testing.ipynb
+++ b/notebooks/testing.ipynb
@@ -2,12 +2,13 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 1,
"id": "c1f56ae3-a056-4b31-bcab-27c2c97c00f1",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
+ "from dotenv import load_dotenv\n",
"\n",
"from importlib.util import find_spec\n",
"if find_spec(\"rag\") is None:\n",
@@ -17,7 +18,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 2,
"id": "240df289-d9d6-424b-a5b9-e09dcfefd57e",
"metadata": {},
"outputs": [],
@@ -27,7 +28,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 7,
"id": "01305de3-0d6d-46e7-a0f0-2e2a2f72d563",
"metadata": {},
"outputs": [
@@ -35,18 +36,41 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "\u001b[32m2024-04-13 22:30:45.267\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m28\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n",
- "\u001b[32m2024-04-13 22:30:45.278\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m43\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists...\u001b[0m\n"
+ "\u001b[32m2024-04-20 20:49:04.904\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.document\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m28\u001b[0m - \u001b[34m\u001b[1mCreating documents table if it does not exist...\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:49:04.922\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36m__configure\u001b[0m:\u001b[36m43\u001b[0m - \u001b[34m\u001b[1mCollection knowledge-base already exists!\u001b[0m\n"
]
}
],
"source": [
- "client = Retriever().vec_db"
+ "load_dotenv()\n",
+ "ret = Retriever()\n",
+ "vecdb = ret.vec_db"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 5,
+ "id": "c83210e2-a5a3-46eb-bd77-aedc0d223190",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CountResult(count=17918)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vecdb.client.count(\"knowledge-base\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
"id": "4fcb1862-19a3-482b-b397-ce7be074c187",
"metadata": {},
"outputs": [
@@ -56,7 +80,7 @@
"True"
]
},
- "execution_count": 32,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -67,9 +91,170 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"id": "c86d538c-d613-4e4d-8dd2-3abcea2cfeef",
"metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CollectionInfo(status=<CollectionStatus.GREEN: 'green'>, optimizer_status=<OptimizersStatusOneOf.OK: 'ok'>, vectors_count=17918, indexed_vectors_count=0, points_count=17918, segments_count=8, config=CollectionConfig(params=CollectionParams(vectors=VectorParams(size=1024, distance=<Distance.COSINE: 'Cosine'>, hnsw_config=None, quantization_config=None, on_disk=None), shard_number=1, sharding_method=None, replication_factor=1, write_consistency_factor=1, read_fan_out_factor=None, on_disk_payload=True, sparse_vectors=None), hnsw_config=HnswConfig(m=16, ef_construct=100, full_scan_threshold=10000, max_indexing_threads=0, on_disk=False, payload_m=None), optimizer_config=OptimizersConfig(deleted_threshold=0.2, vacuum_min_vector_number=1000, default_segment_number=0, max_segment_size=None, memmap_threshold=None, indexing_threshold=20000, flush_interval_sec=5, max_optimization_threads=None), wal_config=WalConfig(wal_capacity_mb=32, wal_segments_ahead=0), quantization_config=None), payload_schema={})"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vecdb.client.get_collection(\"knowledge-base\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c346ef16-4884-4c3a-b03c-9f789ca00212",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m2024-04-20 20:49:52.237\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.retriever\u001b[0m:\u001b[36mretrieve\u001b[0m:\u001b[36m49\u001b[0m - \u001b[34m\u001b[1mFinding documents matching query: what is a convex function?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:49:52.239\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m41\u001b[0m - \u001b[34m\u001b[1mEncoding query: what is a convex function?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:49:53.191\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m58\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:49:53.198\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m65\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Document(title='Haskell Design Patterns.pdf', text='N\\nnon-tail\\trecursion\\t/\\t\\nNon-tail\\trecursion'),\n",
+ " Document(title='LinuxNotesForProfessionals.pdf', text='GoalKicker.com – Linux ® Notes for Professionals 44sudo systemctl disable sshd.service\\nDebian\\nsudo /etc/init.d/ssh stop\\nsudo systemctl disable sshd.service\\nArch Linux\\nsudo killall sshd\\nsudo systemctl disable sshd.service'),\n",
+ " Document(title='Kubernetes_Deployment_Antipatterns_v1.1.pdf', text='Anti-pattern 12\\nNot using the Helm package manager\\n37'),\n",
+ " Document(title='Docker_anti_patterns_vertical_3.pdf', text='Docker Anti-PatternsCONTINUOUS DEPLOYMENT / DELIVERY'),\n",
+ " Document(title='Haskell Design Patterns.pdf', text='B\\nbind\\tchain\\nand\\tmonad\\t/\\t\\nMonads\\tand\\tthe\\tbind\\tchain')]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ret.retrieve(\"what is a convex function?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "36d174df-aa21-4708-ad52-43bb8cfa80e5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m2024-04-20 20:50:18.285\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.retriever\u001b[0m:\u001b[36mretrieve\u001b[0m:\u001b[36m49\u001b[0m - \u001b[34m\u001b[1mFinding documents matching query: what is a hidden markov model?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:50:18.286\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m41\u001b[0m - \u001b[34m\u001b[1mEncoding query: what is a hidden markov model?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:50:18.357\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m58\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:50:18.364\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m65\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Document(title='thinkocaml.pdf', text='44 Chapter 5. Recursive Functions'),\n",
+ " Document(title='Haskell Design Patterns.pdf', text='N\\nnon-tail\\trecursion\\t/\\t\\nNon-tail\\trecursion'),\n",
+ " Document(title='thinkdsp.pdf', text='Think DSP\\nDigital Signal Processing in Python\\nVersion 1.1.4\\nAllen B. Downey\\nGreen Tea Press\\nNeedham, Massachusetts'),\n",
+ " Document(title='PFP-0.1.pdf', text='ii'),\n",
+ " Document(title='Learning Bayesian Networks(Neapolitan, Richard).pdf', text='ii')]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ret.retrieve(\"what is a hidden markov model?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "091c5da0-3a9f-4100-8b2d-ee93a8cf3234",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m2024-04-20 20:50:42.102\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.retriever\u001b[0m:\u001b[36mretrieve\u001b[0m:\u001b[36m49\u001b[0m - \u001b[34m\u001b[1mFinding documents matching query: what is the weather today?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:50:42.104\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m41\u001b[0m - \u001b[34m\u001b[1mEncoding query: what is the weather today?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:50:42.175\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m58\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:50:42.181\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m65\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Document(title='Haskell Design Patterns.pdf', text='N\\nnon-tail\\trecursion\\t/\\t\\nNon-tail\\trecursion'),\n",
+ " Document(title='category-theory-for-programmers.pdf', text='Category Theory\\nfor Programmers\\nByBartosz Milewski\\ncompiled and edited by\\nIgal Tabachnik'),\n",
+ " Document(title='CNotesForProfessionals.pdf', text='Section 22.10: Pointer to Pointer 141 .............................................................................................................................. \\nSection 22.11: void* pointers as arguments and return values to standard functions 141 ....................................... \\nSection 22.12: Same Asterisk, Di\\ue023erent Meanings 142 ................................................................................................. \\nChapter 23: Sequence points 144 .............................................................................................................................. \\nSection 23.1: Unsequenced expressions 144 .................................................................................................................. \\nSection 23.2: Sequenced expressions 144 ..................................................................................................................... \\nSection 23.3: Indeterminately sequenced expressions 145 ......................................................................................... \\nChapter 24: Function Pointers 146 ........................................................................................................................... \\nSection 24.1: Introduction 146 .......................................................................................................................................... \\nSection 24.2: Returning Function Pointers from a Function 146 ................................................................................. \\nSection 24.3: Best Practices 147 ..................................................................................................................................... \\nSection 24.4: Assigning a Function Pointer 149 ............................................................................................................. \\nSection 24.5: Mnemonic for writing function pointers 149 ........................................................................................... \\nSection 24.6: Basics 150 ................................................................................................................................................... \\nChapter 25: Function Parameters 152 .................................................................................................................... \\nSection 25.1: Parameters are passed by value 152 ...................................................................................................... \\nSection 25.2: Passing in Arrays to Functions 152 .......................................................................................................... \\nSection 25.3: Order of function parameter execution 153 ........................................................................................... \\nSection 25.4: Using pointer parameters to return multiple values 153 ...................................................................... \\nSection 25.5: Example of function returning struct containing values with error codes 154 ................................... \\nChapter 26: Pass 2D-arrays to functions 156 ..................................................................................................... \\nSection 26.1: Pass a 2D-array to a function 156 ........................................................................................................... \\nSection 26.2: Using flat arrays as 2D arrays 162 .......................................................................................................... \\nChapter 27: Error handling 163 .................................................................................................................................. \\nSection 27.1: errno 163 .....................................................................................................................................'),\n",
+ " Document(title='Beginning Haskell_ A Project-Based Approach.pdf', text='do Notation �������������������������������������������������������������������������������������������������������������������������������������������������������� 150\\nMonad Laws ������������������������������������������������������������������������������������������������������������������������������������������������������ 152'),\n",
+ " Document(title='Learning Bayesian Networks(Neapolitan, Richard).pdf', text='viii CONTENTS')]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ret.retrieve(\"what is the weather today?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "8830235e-1cdc-46b1-9a0f-96d7df6fc183",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m2024-04-20 20:51:23.336\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.retriever\u001b[0m:\u001b[36mretrieve\u001b[0m:\u001b[36m49\u001b[0m - \u001b[34m\u001b[1mFinding documents matching query: what is ocaml?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:51:23.337\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.encoder\u001b[0m:\u001b[36mencode_query\u001b[0m:\u001b[36m41\u001b[0m - \u001b[34m\u001b[1mEncoding query: what is ocaml?\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:51:23.409\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m58\u001b[0m - \u001b[34m\u001b[1mSearching for vectors...\u001b[0m\n",
+ "\u001b[32m2024-04-20 20:51:23.416\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mrag.retriever.vector\u001b[0m:\u001b[36msearch\u001b[0m:\u001b[36m65\u001b[0m - \u001b[34m\u001b[1mGot 5 hits in the vector db with limit=5\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Document(title='Haskell Design Patterns.pdf', text='N\\nnon-tail\\trecursion\\t/\\t\\nNon-tail\\trecursion'),\n",
+ " Document(title='category-theory-for-programmers.pdf', text='Category Theory\\nfor Programmers\\nByBartosz Milewski\\ncompiled and edited by\\nIgal Tabachnik'),\n",
+ " Document(title='Docker_anti_patterns_vertical_3.pdf', text='Docker Anti-PatternsCONTINUOUS DEPLOYMENT / DELIVERY'),\n",
+ " Document(title='thinkocaml.pdf', text='90 Chapter 12. Hashtables'),\n",
+ " Document(title='Beginning Haskell_ A Project-Based Approach.pdf', text='do Notation �������������������������������������������������������������������������������������������������������������������������������������������������������� 150\\nMonad Laws ������������������������������������������������������������������������������������������������������������������������������������������������������ 152')]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ret.retrieve(\"what is ocaml?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d6d0c2d0-25a3-446b-a103-b8b56c82296c",
+ "metadata": {},
"outputs": [],
"source": []
}
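
The raw vector hits above are mostly unrelated to the queries (a question about convex functions returns Haskell index pages), which is the failure mode the cross-encoder reranker is meant to catch. A minimal sketch of filtering such hits, mirroring what Reranker.rank in rag/retriever/rerank/local.py does; the checkpoint name and relevance threshold are assumptions:

    from sentence_transformers import CrossEncoder

    # 'ret' is the Retriever instance from the cells above.
    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # illustrative checkpoint
    query = "what is a convex function?"
    docs = [d.text for d in ret.retrieve(query)]
    results = encoder.rank(query, docs, return_documents=False, top_k=5)
    relevant = [docs[r["corpus_id"]] for r in results if r["score"] > 0.5]  # threshold assumed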
diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py
index 28a87e7..fb0cc5b 100644
--- a/rag/generator/cohere.py
+++ b/rag/generator/cohere.py
@@ -1,6 +1,6 @@
import os
from dataclasses import asdict
-from typing import Any, Generator
+from typing import Any, Dict, Generator, List, Optional
import cohere
from loguru import logger as log
@@ -13,12 +13,15 @@ class Cohere(metaclass=AbstractGenerator):
def __init__(self) -> None:
self.client = cohere.Client(os.environ["COHERE_API_KEY"])
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+    def generate(
+        self, prompt: Prompt, history: Optional[List[Dict[str, str]]]
+    ) -> Generator[Any, Any, Any]:
log.debug("Generating answer from cohere...")
query = f"{prompt.query}\n\n{ANSWER_INSTRUCTION}"
for event in self.client.chat_stream(
message=query,
documents=[asdict(d) for d in prompt.documents],
+            chat_history=history,
prompt_truncation="AUTO",
):
if event.event_type == "text-generation":
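
For reference, cohere's chat_history entries use upper-case roles and a message key, matching what Message.as_dict emits for the cohere backend. A minimal streaming sketch against the same client API the generator uses; the query and history content are illustrative:

    import os

    import cohere

    co = cohere.Client(os.environ["COHERE_API_KEY"])
    history = [
        {"role": "USER", "message": "what is a convex function?"},
        {"role": "CHATBOT", "message": "A function that lies below each of its chords."},
    ]
    for event in co.chat_stream(message="and a concave one?", chat_history=history):
        if event.event_type == "text-generation":
            print(event.text, end="")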
diff --git a/rag/generator/ollama.py b/rag/generator/ollama.py
index 52521ca..9bf551a 100644
--- a/rag/generator/ollama.py
+++ b/rag/generator/ollama.py
@@ -36,8 +36,8 @@ class Ollama(metaclass=AbstractGenerator):
)
return metaprompt
- def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
+ def generate(self, prompt: Prompt, memory: Memory) -> Generator[Any, Any, Any]:
log.debug("Generating answer with ollama...")
metaprompt = self.__metaprompt(prompt)
-    for chunk in ollama.generate(model=self.model, prompt=metaprompt, stream=True):
-        yield chunk["response"]
+    messages = memory.get("ollama") + [{"role": "user", "content": metaprompt}]
+    for chunk in ollama.chat(model=self.model, messages=messages, stream=True):
+        yield chunk["message"]["content"]
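
ollama.chat streams message-shaped chunks, unlike ollama.generate which streams a response field. A minimal sketch of the messages format it expects (the model name is an assumption):

    import ollama

    messages = [{"role": "user", "content": "what is a convex function?"}]
    for chunk in ollama.chat(model="llama3", messages=messages, stream=True):
        print(chunk["message"]["content"], end="")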
diff --git a/rag/retriever/memory.py b/rag/retriever/memory.py
new file mode 100644
index 0000000..c4455ed
--- /dev/null
+++ b/rag/retriever/memory.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from typing import Dict, List, Tuple
+
+
+@dataclass
+class Message:
+    role: str
+    message: str
+
+    def as_dict(self, model: str) -> Dict[str, str]:
+        if model == "cohere":
+            match self.role:
+                case "user":
+                    role = "USER"
+                case _:
+                    role = "CHATBOT"
+
+            return {"role": role, "message": self.message}
+        else:
+            return {"role": self.role, "content": self.message}
+
+
+@dataclass
+class Log:
+    user: Message
+    bot: Message
+
+    def get(self) -> Tuple[Message, Message]:
+        return (self.user, self.bot)
+
+
+class Memory:
+    def __init__(self, reranker) -> None:
+        self.history: List[Log] = []
+        self.reranker = reranker
+        self.user = "user"
+        self.bot = "assistant"
+
+    def add(self, prompt: str, response: str):
+        self.history.append(
+            Log(
+                user=Message(role=self.user, message=prompt),
+                bot=Message(role=self.bot, message=response),
+            )
+        )
+
+    def get(self, model: str) -> List[Dict[str, str]]:
+        return [m.as_dict(model) for log in self.history for m in log.get()]
+
+    def reset(self):
+        self.history = []
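
A minimal usage sketch of the new Memory class; the reranker is passed as None here because add/get do not touch it yet:

    from rag.retriever.memory import Memory

    memory = Memory(reranker=None)
    memory.add("what is a convex function?", "A function lying below each of its chords.")
    cohere_history = memory.get("cohere")  # [{"role": "USER", "message": ...}, {"role": "CHATBOT", ...}]
    ollama_history = memory.get("ollama")  # [{"role": "user", "content": ...}, {"role": "assistant", ...}]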
diff --git a/rag/retriever/rerank/local.py b/rag/retriever/rerank/local.py
index 75fedd8..8e94882 100644
--- a/rag/retriever/rerank/local.py
+++ b/rag/retriever/rerank/local.py
@@ -1,9 +1,11 @@
import os
+from typing import List
from loguru import logger as log
from sentence_transformers import CrossEncoder
from rag.generator.prompt import Prompt
+from rag.retriever.memory import Log
from rag.retriever.rerank.abstract import AbstractReranker
@@ -33,3 +35,22 @@ class Reranker(metaclass=AbstractReranker):
prompt.documents[r.get("corpus_id", 0)] for r in ranking
]
return prompt
+
+    def rank_memory(self, prompt: Prompt, history: List[Log]) -> List[Log]:
+        if history:
+            results = self.model.rank(
+                query=prompt.query,
+                documents=[m.bot.message for m in history],
+                return_documents=False,
+                top_k=self.top_k,
+            )
+            ranking = list(
+                filter(
+                    lambda x: x.get("score", 0.0) > self.relevance_threshold, results
+                )
+            )
+            log.debug(
+                f"Reranking gave {len(ranking)} relevant messages of {len(history)}"
+            )
+            history = [history[r.get("corpus_id", 0)] for r in ranking]
+        return history
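
A sketch of pruning chat history by relevance to the current query with rank_memory; the Prompt constructor arguments are assumed from how cohere.py reads prompt.query and prompt.documents, and reranker stands for an already constructed Reranker:

    from rag.generator.prompt import Prompt
    from rag.retriever.memory import Log, Message

    history = [
        Log(
            user=Message(role="user", message="what is ocaml?"),
            bot=Message(role="assistant", message="A functional language in the ML family."),
        )
    ]
    prompt = Prompt(query="tell me more about ocaml", documents=[])
    relevant = reranker.rank_memory(prompt, history)  # only logs above the relevance threshold survive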
diff --git a/rag/ui.py b/rag/ui.py
index 36e8c4c..ddb3d78 100644
--- a/rag/ui.py
+++ b/rag/ui.py
@@ -13,18 +13,6 @@ from rag.retriever.retriever import Retriever
from rag.retriever.vector import Document
-@dataclass
-class Message:
- role: str
- message: str
-
- def as_dict(self, model: str) -> Dict[str, str]:
- if model == "cohere":
- return {"role": self.role, "message": self.message}
- else:
- return {"role": self.role, "content": self.message}
-
-
def set_chat_users():
log.debug("Setting user and bot value")
ss = st.session_state
@@ -38,13 +26,11 @@ def load_retriever():
st.session_state.retriever = Retriever()
-# @st.cache_resource
def load_generator(model: str):
log.debug("Loading generator model")
st.session_state.generator = get_generator(model)
-# @st.cache_resource
def load_reranker(model: str):
log.debug("Loading reranker model")
st.session_state.reranker = get_reranker(model)
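
A sketch of how the memory could be wired into the streamlit session next to the existing loaders; load_memory is a hypothetical name, not part of this commit:

    import streamlit as st
    from loguru import logger as log

    from rag.retriever.memory import Memory

    def load_memory():
        # Hypothetical loader mirroring load_retriever/load_generator above.
        log.debug("Loading chat memory")
        ss = st.session_state
        if "memory" not in ss:
            ss.memory = Memory(reranker=ss.reranker)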