From 04686f497f120096435da72c6546306eb292846a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 23 Apr 2024 00:50:25 +0200 Subject: Add delete all cli --- rag/drop.py | 24 ++++++++++++++++++++++++ rag/retriever/document.py | 10 ++++++++++ rag/retriever/vector.py | 6 +++++- 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 rag/drop.py diff --git a/rag/drop.py b/rag/drop.py new file mode 100644 index 0000000..89ae755 --- /dev/null +++ b/rag/drop.py @@ -0,0 +1,24 @@ +import click +from dotenv import load_dotenv +from loguru import logger as log + +from rag.retriever.retriever import Retriever + + +def drop(): + log.debug("Dropping documents") + retriever = Retriever() + doc_db = retriever.doc_db + doc_db.delete_all() + vec_db = retriever.vec_db + vec_db.delete_collection() + + +@click.confirmation_option(prompt="Are you sure you want to drop the db?") +def main(): + drop() + + +if __name__ == "__main__": + load_dotenv() + main() diff --git a/rag/retriever/document.py b/rag/retriever/document.py index 8a50f01..132ec4b 100644 --- a/rag/retriever/document.py +++ b/rag/retriever/document.py @@ -34,6 +34,16 @@ class DocumentDB: log.debug("Hashing document...") return hashlib.sha256(blob.as_bytes()).hexdigest() + def delete_all(self): + with self.conn.cursor() as cur: + cur.execute( + """ + TRUNCATE TABLE + document + """ + ) + self.conn.commit() + def add(self, blob: Blob) -> bool: with self.conn.cursor() as cur: hash = self.__hash(blob) diff --git a/rag/retriever/vector.py b/rag/retriever/vector.py index e8d22c2..b72a3c1 100644 --- a/rag/retriever/vector.py +++ b/rag/retriever/vector.py @@ -22,7 +22,7 @@ class Document: class VectorDB: - def __init__(self, score_threshold: float = 0.6): + def __init__(self, score_threshold: float = 0.5): self.dim = int(os.environ["EMBEDDING_DIM"]) self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] self.client = QdrantClient(url=os.environ["QDRANT_URL"]) @@ -42,6 +42,10 @@ class VectorDB: else: log.debug(f"Collection {self.collection_name} already exists!") + def delete_collection(self): + log.info(f"Deleting collection {self.collection_name}") + self.client.delete_collection(self.collection_name) + def add(self, points: List[Point]): log.debug(f"Inserting {len(points)} vectors into the vector db...") self.client.upload_points( -- cgit v1.2.3-70-g09d2