From 36bb9e4cd1f42ae6e60ed4e296beab0eb462f376 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 5 Apr 2024 18:31:27 +0200 Subject: Refactor encoder --- rag/llm/encoder.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) (limited to 'rag') diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index 095c6af..a686aaf 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -1,9 +1,13 @@ import os from typing import List +from uuid import uuid4 import numpy as np import ollama from langchain_core.documents import Document +from qdrant_client.http.models import StrictFloat + +from rag.db.embeddings import Point class Encoder: @@ -11,13 +15,18 @@ class Encoder: self.model = os.environ["ENCODER_MODEL"] self.query_prompt = "Represent this sentence for searching relevant passages: " - def __encode(self, prompt: str) -> np.ndarray: - x = ollama.embeddings(model=self.model, prompt=prompt) - x = np.array([x["embedding"]]).astype("float32") - return x + def __encode(self, prompt: str) -> List[StrictFloat]: + return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) def encode_document(self, chunks: List[Document]) -> np.ndarray: - return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks]) + return [ + Point( + id=str(uuid4()), + vector=self.__encode(chunk.page_content), + payload={"text": chunk.page_content}, + ) + for chunk in chunks + ] def query(self, query: str) -> np.ndarray: query = self.query_prompt + query -- cgit v1.2.3-70-g09d2