From 7dfd91176fecfc442783260183ad8f34807cc284 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 5 Apr 2024 02:04:20 +0200 Subject: Update encoder --- rag/llm/encoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'rag/llm') diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index d5e0566..095c6af 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -3,6 +3,7 @@ from typing import List import numpy as np import ollama +from langchain_core.documents import Document class Encoder: @@ -11,12 +12,12 @@ class Encoder: self.query_prompt = "Represent this sentence for searching relevant passages: " def __encode(self, prompt: str) -> np.ndarray: - x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt) + x = ollama.embeddings(model=self.model, prompt=prompt) x = np.array([x["embedding"]]).astype("float32") return x - def encode(self, doc: List[str]) -> List[np.ndarray]: - return [self.__encode(chunk) for chunk in doc] + def encode_document(self, chunks: List[Document]) -> np.ndarray: + return np.concatenate([self.__encode(chunk.page_content) for chunk in chunks]) def query(self, query: str) -> np.ndarray: query = self.query_prompt + query -- cgit v1.2.3-70-g09d2