summary refs log tree commit diff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-05 18:31:27 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-05 18:31:27 +0200
commit 36bb9e4cd1f42ae6e60ed4e296beab0eb462f376 (patch)
tree 6c08d32bd5e384ae6704db8863f4bc3ea41e60be /rag/llm/encoder.py
parent 1dfaf80c75afa84b6d03a0013eb1fd94d0257226 (diff)
Refactor encoder
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r-- rag/llm/encoder.py 19
1 files changed, 14 insertions, 5 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 095c6af..a686aaf 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,9 +1,13 @@
import os
from typing import List
+from uuid import uuid4
import numpy as np
import ollama
from langchain_core.documents import Document
+from qdrant_client.http.models import StrictFloat
+
+from rag.db.embeddings import Point
class Encoder:
@@ -11,13 +15,18 @@ class Encoder:
self.model = os.environ["ENCODER_MODEL"]
self.query_prompt = "Represent this sentence for searching relevant passages: "
- def __encode(self, prompt: str) -> np.ndarray:
- x = ollama.embeddings(model=self.model, prompt=prompt)
- x = np.array([x["embedding"]]).astype("float32")
- return x
+ def __encode(self, prompt: str) -> List[StrictFloat]:
+ return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
def encode_document(self, chunks: List[Document]) -> np.ndarray:
- return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks])
+ return [
+ Point(
+ id=str(uuid4()),
+ vector=self.__encode(chunk.page_content),
+ payload={"text": chunk.page_content},
+ )
+ for chunk in chunks
+ ]
def query(self, query: str) -> np.ndarray:
query = self.query_prompt + query