summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-23 00:50:25 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-23 00:50:25 +0200
commit 04686f497f120096435da72c6546306eb292846a (patch)
tree 8203df9747e2cb729a8317045e5bdc4512241ac4 /rag
parent f5adcbb62b10110dc7417c5a07ef6461782f5a10 (diff)
Add delete all cli
Diffstat (limited to 'rag')
-rw-r--r-- rag/drop.py 24
-rw-r--r-- rag/retriever/document.py 10
-rw-r--r-- rag/retriever/vector.py 6
3 files changed, 39 insertions, 1 deletions
diff --git a/rag/drop.py b/rag/drop.py
new file mode 100644
index 0000000..89ae755
--- /dev/null
+++ b/rag/drop.py
@@ -0,0 +1,24 @@
+import click
+from dotenv import load_dotenv
+from loguru import logger as log
+
+from rag.retriever.retriever import Retriever
+
+
+def drop():
+    """Wipe both stores owned by a freshly constructed ``Retriever``.
+
+    The document database is emptied first via ``delete_all()``; the
+    vector collection is then removed via ``delete_collection()``.
+    """
+    log.debug("Dropping documents")
+    retriever = Retriever()
+    # Same order as before: documents first, then the vector collection.
+    retriever.doc_db.delete_all()
+    retriever.vec_db.delete_collection()
+
+
+# ``confirmation_option`` only attaches an option to the function; it takes
+# effect when the function is turned into a click command. Without
+# ``@click.command()`` calling ``main()`` ran ``drop()`` with no prompt.
+@click.command()
+@click.confirmation_option(prompt="Are you sure you want to drop the db?")
+def main():
+    """Drop the document database and the vector collection.
+
+    Asks for confirmation first (skippable with ``--yes``); the command
+    aborts without touching any data when the prompt is declined.
+    """
+    drop()
+
+
+if __name__ == "__main__":
+ # Load settings from a .env file before anything else runs; presumably
+ # needed so the os.environ lookups in VectorDB.__init__ succeed.
+ load_dotenv()
+ main()
diff --git a/rag/retriever/document.py b/rag/retriever/document.py
index 8a50f01..132ec4b 100644
--- a/rag/retriever/document.py
+++ b/rag/retriever/document.py
@@ -34,6 +34,16 @@ class DocumentDB:
log.debug("Hashing document...")
return hashlib.sha256(blob.as_bytes()).hexdigest()
+ def delete_all(self):
+ """Remove every row from the ``document`` table and commit the change."""
+ with self.conn.cursor() as cur:
+ cur.execute(
+ """
+ TRUNCATE TABLE
+ document
+ """
+ )
+ # Persist the truncation on the shared connection.
+ self.conn.commit()
+
def add(self, blob: Blob) -> bool:
with self.conn.cursor() as cur:
hash = self.__hash(blob)
diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py
index e8d22c2..b72a3c1 100644
--- a/rag/retriever/vector.py
+++ b/rag/retriever/vector.py
@@ -22,7 +22,7 @@ class Document:
class VectorDB:
- def __init__(self, score_threshold: float = 0.6):
+ def __init__(self, score_threshold: float = 0.5):
self.dim = int(os.environ["EMBEDDING_DIM"])
self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
self.client = QdrantClient(url=os.environ["QDRANT_URL"])
@@ -42,6 +42,10 @@ class VectorDB:
else:
log.debug(f"Collection {self.collection_name} already exists!")
+    def delete_collection(self):
+        """Delete this instance's configured collection from the vector store.
+
+        Logs the collection name at info level, then asks the client to
+        delete the collection named by ``self.collection_name``.
+        """
+        log.info(f"Deleting collection {self.collection_name}")
+        self.client.delete_collection(collection_name=self.collection_name)
+
def add(self, points: List[Point]):
log.debug(f"Inserting {len(points)} vectors into the vector db...")
self.client.upload_points(