summary | refs | log | tree | commit | diff
path: root/rag/llms/encoder.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-05 00:27:27 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-05 00:27:27 +0200
commit: 633f180eb0ccdc4772d5d705873cef1e33507976 (patch)
tree: 313d9653c188b4b788b5d929633c794eb8d78ae2 /rag/llms/encoder.py
parent: 47084bfa9fe2f21bfc079bdd4d057b5768083454 (diff)
Wip llms
Diffstat (limited to 'rag/llms/encoder.py')
-rw-r--r-- rag/llms/encoder.py 24
1 files changed, 24 insertions, 0 deletions
diff --git a/rag/llms/encoder.py b/rag/llms/encoder.py
new file mode 100644
index 0000000..758b523
--- /dev/null
+++ b/rag/llms/encoder.py
@@ -0,0 +1,24 @@
+from typing import List
+import ollama
+import numpy as np
+
+# Name of the embedding model passed as `model=` to `ollama.embeddings`.
+# FIXME: .env -- presumably this should be loaded from configuration
+# instead of being hard-coded here; confirm the intended mechanism.
+ENCODER_MODEL = "mxbai-embed-large"
+
+
+class Encoder:
+ def __init__(self) -> None:
+ self.query_prompt = "Represent this sentence for searching relevant passages: "
+
+ def __encode(self, prompt: str) -> np.ndarray:
+ x = ollama.embeddings(model=ENCODER_MODEL, prompt=prompt)
+ x = np.array([x["embedding"]]).astype("float32")
+ return x
+
+ def encode(self, doc: List[str]) -> List[np.ndarray]:
+ return [self.__encode(chunk) for chunk in doc]
+
+ def query(self, query: str) -> np.ndarray:
+ query = self.query_prompt + query
+ return self.__encode(query)
+