summaryrefslogtreecommitdiff
path: root/rag/llm
diff options
context:
space:
mode:
Diffstat (limited to 'rag/llm')
-rw-r--r--rag/llm/encoder.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 095c6af..a686aaf 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,9 +1,13 @@
import os
from typing import List
+from uuid import uuid4
import numpy as np
import ollama
from langchain_core.documents import Document
+from qdrant_client.http.models import StrictFloat
+
+from rag.db.embeddings import Point
class Encoder:
@@ -11,13 +15,18 @@ class Encoder:
self.model = os.environ["ENCODER_MODEL"]
self.query_prompt = "Represent this sentence for searching relevant passages: "
- def __encode(self, prompt: str) -> np.ndarray:
- x = ollama.embeddings(model=self.model, prompt=prompt)
- x = np.array([x["embedding"]]).astype("float32")
- return x
+ def __encode(self, prompt: str) -> List[StrictFloat]:
+ return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
def encode_document(self, chunks: List[Document]) -> np.ndarray:
- return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks])
+ return [
+ Point(
+ id=str(uuid4()),
+ vector=self.__encode(chunk.page_content),
+ payload={"text": chunk.page_content},
+ )
+ for chunk in chunks
+ ]
def query(self, query: str) -> np.ndarray:
query = self.query_prompt + query