summaryrefslogtreecommitdiff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-05 00:42:02 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-05 00:42:02 +0200
commit1cf0a401054c3e3ebde60bfd73ad15e39bc531e6 (patch)
tree633314ef342c35213cfa01607dd3c98e77b7cdd2 /rag/llm/encoder.py
parent633f180eb0ccdc4772d5d705873cef1e33507976 (diff)
Rename llm
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r--rag/llm/encoder.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
new file mode 100644
index 0000000..d5e0566
--- /dev/null
+++ b/rag/llm/encoder.py
@@ -0,0 +1,23 @@
+import os
+from typing import List
+
+import numpy as np
+import ollama
+
+
+class Encoder:
+ def __init__(self) -> None:
+ self.model = os.environ["ENCODER_MODEL"]
+ self.query_prompt = "Represent this sentence for searching relevant passages: "
+
+ def __encode(self, prompt: str) -> np.ndarray:
+ x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt)
+ x = np.array([x["embedding"]]).astype("float32")
+ return x
+
+ def encode(self, doc: List[str]) -> List[np.ndarray]:
+ return [self.__encode(chunk) for chunk in doc]
+
+ def query(self, query: str) -> np.ndarray:
+ query = self.query_prompt + query
+ return self.__encode(query)