summaryrefslogtreecommitdiff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r--rag/llm/encoder.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index 95f3c6a..a59b1b4 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,5 +1,6 @@
import os
-from typing import Iterator, List
+from pathlib import Path
+from typing import List, Dict
from uuid import uuid4
import ollama
@@ -13,6 +14,7 @@ try:
except ModuleNotFoundError:
from db.vector import Point
+
class Encoder:
def __init__(self) -> None:
self.model = os.environ["ENCODER_MODEL"]
@@ -21,13 +23,20 @@ class Encoder:
def __encode(self, prompt: str) -> List[StrictFloat]:
return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
- def encode_document(self, chunks: Iterator[Document]) -> List[Point]:
+ def __get_source(self, metadata: Dict[str, str]) -> str:
+ source = metadata["source"]
+ return Path(source).name
+
+ def encode_document(self, chunks: List[Document]) -> List[Point]:
log.debug("Encoding document...")
return [
Point(
id=uuid4().hex,
vector=self.__encode(chunk.page_content),
- payload={"text": chunk.page_content},
+ payload={
+ "text": chunk.page_content,
+ "source": self.__get_source(chunk.metadata),
+ },
)
for chunk in chunks
]