summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 01:22:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 01:22:13 +0200
commit3ba6eca92a339e28ffce14adf46d2fb71e6f4958 (patch)
treed03d23c080c51382e4c987f52932253a0f814136 /rag
parent13ac875b2269756045834d7a64e7b35acb9ce0b4 (diff)
Refactor
Diffstat (limited to 'rag')
-rw-r--r--rag/llm/encoder.py12
-rw-r--r--rag/llm/generator.py14
2 files changed, 14 insertions, 12 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index a686aaf..94d5559 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,13 +1,13 @@
import os
-from typing import List
+from typing import List, Optional
from uuid import uuid4
-import numpy as np
import ollama
from langchain_core.documents import Document
+from loguru import logger as log
from qdrant_client.http.models import StrictFloat
-from rag.db.embeddings import Point
+from rag.db.vector import Point
class Encoder:
@@ -18,7 +18,8 @@ class Encoder:
def __encode(self, prompt: str) -> List[StrictFloat]:
return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
- def encode_document(self, chunks: List[Document]) -> np.ndarray:
+ def encode_document(self, chunks: List[Document]) -> List[Point]:
+ log.debug("Encoding document...")
return [
Point(
id=str(uuid4()),
@@ -28,6 +29,7 @@ class Encoder:
for chunk in chunks
]
- def query(self, query: str) -> np.ndarray:
+ def encode_query(self, query: str) -> List[StrictFloat]:
+ log.debug(f"Encoding query: {query}")
query = self.query_prompt + query
return self.__encode(query)
diff --git a/rag/llm/generator.py b/rag/llm/generator.py
index cbe9474..5f164e7 100644
--- a/rag/llm/generator.py
+++ b/rag/llm/generator.py
@@ -2,16 +2,14 @@ import os
from dataclasses import dataclass
import ollama
+from loguru import logger as log
@dataclass
class Prompt:
- question: str
+ query: str
context: str
- # def context(self) -> str:
- # return "\n".join(point.payload["text"] for point in self.points)
-
class Generator:
def __init__(self) -> None:
@@ -22,14 +20,16 @@ class Generator:
f"You are a {role}.\n"
"Answer the following question using the provided context.\n"
"If you can't find the answer, do not pretend you know it,"
- 'but answer "I don\'t know".'
- f"Question: {prompt.question.strip()}\n\n"
+ 'but answer "I don\'t know".\n\n'
+ f"Question: {prompt.query.strip()}\n\n"
"Context:\n"
f"{prompt.context.strip()}\n\n"
"Answer:\n"
)
return metaprompt
- def generate(self, role: str, prompt: Prompt) -> str:
+ def generate(self, prompt: Prompt, role: str) -> str:
+ log.debug("Generating answer...")
metaprompt = self.__metaprompt(role, prompt)
+ print(f"metaprompt = \n{metaprompt}")
return ollama.generate(model=self.model, prompt=metaprompt)