author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2024-04-09 00:14:00 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2024-04-09 00:14:00 +0200
commit     91ddb3672e514fa9824609ff047d7cab0c65631a (patch)
tree       009fd82618588d2960b5207128e86875f73cccdc /rag/llm
parent     d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 (diff)
Refactor
Diffstat (limited to 'rag/llm')
-rw-r--r--   rag/llm/__init__.py           0
-rw-r--r--   rag/llm/cohere_generator.py  29
-rw-r--r--   rag/llm/encoder.py           47
-rw-r--r--   rag/llm/ollama_generator.py  76
4 files changed, 0 insertions, 152 deletions
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/rag/llm/__init__.py
+++ /dev/null
diff --git a/rag/llm/cohere_generator.py b/rag/llm/cohere_generator.py
deleted file mode 100644
index a6feacd..0000000
--- a/rag/llm/cohere_generator.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import os
-from typing import Any, Generator
-import cohere
-
-from dataclasses import asdict
-try:
-    from rag.llm.ollama_generator import Prompt
-except ModuleNotFoundError:
-    from llm.ollama_generator import Prompt
-from loguru import logger as log
-
-
-class CohereGenerator:
-    def __init__(self) -> None:
-        self.client = cohere.Client(os.environ["COHERE_API_KEY"])
-
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
-        log.debug("Generating answer from cohere")
-        for event in self.client.chat_stream(
-            message=prompt.query,
-            documents=[asdict(d) for d in prompt.documents],
-            prompt_truncation="AUTO",
-        ):
-            if event.event_type == "text-generation":
-                yield event.text
-            elif event.event_type == "citation-generation":
-                yield event.citations
-            elif event.event_type == "stream-end":
-                yield event.finish_reason
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
deleted file mode 100644
index a59b1b4..0000000
--- a/rag/llm/encoder.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-from pathlib import Path
-from typing import List, Dict
-from uuid import uuid4
-
-import ollama
-from langchain_core.documents import Document
-from loguru import logger as log
-from qdrant_client.http.models import StrictFloat
-
-
-try:
-    from rag.db.vector import Point
-except ModuleNotFoundError:
-    from db.vector import Point
-
-
-class Encoder:
-    def __init__(self) -> None:
-        self.model = os.environ["ENCODER_MODEL"]
-        self.query_prompt = "Represent this sentence for searching relevant passages: "
-
-    def __encode(self, prompt: str) -> List[StrictFloat]:
-        return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
-
-    def __get_source(self, metadata: Dict[str, str]) -> str:
-        source = metadata["source"]
-        return Path(source).name
-
-    def encode_document(self, chunks: List[Document]) -> List[Point]:
-        log.debug("Encoding document...")
-        return [
-            Point(
-                id=uuid4().hex,
-                vector=self.__encode(chunk.page_content),
-                payload={
-                    "text": chunk.page_content,
-                    "source": self.__get_source(chunk.metadata),
-                },
-            )
-            for chunk in chunks
-        ]
-
-    def encode_query(self, query: str) -> List[StrictFloat]:
-        log.debug(f"Encoding query: {query}")
-        query = self.query_prompt + query
-        return self.__encode(query)
diff --git a/rag/llm/ollama_generator.py b/rag/llm/ollama_generator.py
deleted file mode 100644
index dd17f8d..0000000
--- a/rag/llm/ollama_generator.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import os
-from dataclasses import dataclass
-from typing import Any, Generator, List
-
-import ollama
-from loguru import logger as log
-
-try:
-    from rag.db.vector import Document
-except ModuleNotFoundError:
-    from db.vector import Document
-
-
-@dataclass
-class Prompt:
-    query: str
-    documents: List[Document]
-
-
-SYSTEM_PROMPT = (
-    "# System Preamble"
-    "## Basic Rules"
-    "When you answer the user's requests, you cite your sources in your answers, according to those instructions."
-    "Answer the following question using the provided context.\n"
-    "## Style Guide"
-    "Unless the user asks for a different style of answer, you should answer "
-    "in full sentences, using proper grammar and spelling."
-)
-
-
-class OllamaGenerator:
-    def __init__(self) -> None:
-        self.model = os.environ["GENERATOR_MODEL"]
-
-    def __context(self, documents: List[Document]) -> str:
-        results = [
-            f"Document: {i}\ntitle: {doc.title}\n{doc.text}"
-            for i, doc in enumerate(documents)
-        ]
-        return "\n".join(results)
-
-    def __metaprompt(self, prompt: Prompt) -> str:
-        # Include sources
-        metaprompt = (
-            f'Question: "{prompt.query.strip()}"\n\n'
-            "Context:\n"
-            "<result>\n"
-            f"{self.__context(prompt.documents)}\n\n"
-            "</result>\n"
-            "Carefully perform the following instructions, in order, starting each "
-            "with a new line.\n"
-            "Firstly, Decide which of the retrieved documents are relevant to the "
-            "user's last input by writing 'Relevant Documents:' followed by "
-            "comma-separated list of document numbers.\n If none are relevant, you "
-            "should instead write 'None'.\n"
-            "Secondly, Decide which of the retrieved documents contain facts that "
-            "should be cited in a good answer to the user's last input by writing "
-            "'Cited Documents:' followed a comma-separated list of document numbers. "
-            "If you dont want to cite any of them, you should instead write 'None'.\n"
-            "Thirdly, Write 'Answer:' followed by a response to the user's last input "
-            "in high quality natural english. Use the retrieved documents to help you. "
-            "Do not insert any citations or grounding markup.\n"
-            "Finally, Write 'Grounded answer:' followed by a response to the user's "
-            "last input in high quality natural english. Use the symbols <co: doc> and "
-            "</co: doc> to indicate when a fact comes from a document in the search "
-            "result, e.g <co: 0>my fact</co: 0> for a fact from document 0."
-        )
-        return metaprompt
-
-    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
-        log.debug("Generating answer...")
-        metaprompt = self.__metaprompt(prompt)
-        for chunk in ollama.generate(
-            model=self.model, prompt=metaprompt, system=SYSTEM_PROMPT, stream=True
-        ):
-            yield chunk
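
For context on what this refactor removes: rag/llm/encoder.py embedded text chunks via Ollama and paired each vector with a payload holding the chunk text and the source file's basename; the prefix it prepends in encode_query is the query instruction used by BGE-style retrieval models. A minimal sketch of driving the deleted Encoder, assuming Point is the (id, vector, payload) container its field names suggest; the chunk content and ENCODER_MODEL value are illustrative, not taken from this repo:

    from langchain_core.documents import Document as LCDocument

    from rag.llm.encoder import Encoder  # pre-refactor path deleted by this commit

    # Assumes ENCODER_MODEL names an Ollama embedding model; which one is
    # not pinned down by this diff.
    encoder = Encoder()

    chunks = [
        LCDocument(
            page_content="Qdrant stores points with optional payloads.",
            metadata={"source": "/docs/qdrant/points.md"},
        )
    ]

    points = encoder.encode_document(chunks)
    # points[0].id      -> random hex UUID
    # points[0].vector  -> List[float] from ollama.embeddings(...)["embedding"]
    # points[0].payload -> {"text": "Qdrant stores ...", "source": "points.md"}

    query_vector = encoder.encode_query("How are payloads stored?")
    # encode_query prepends "Represent this sentence for searching relevant
    # passages: " before embedding, mirroring the document/query asymmetry
    # of BGE-style encoders.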
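The two deleted generators consumed retrieval output as a Prompt. A sketch of the streaming path through the removed OllamaGenerator, assuming rag.db.vector.Document (referenced but not defined in this diff) is a dataclass with at least the title and text fields that __context's f-string implies; the retrieved documents here are stand-ins:

    from rag.db.vector import Document  # referenced by the deleted code, defined elsewhere
    from rag.llm.ollama_generator import OllamaGenerator, Prompt

    generator = OllamaGenerator()  # requires GENERATOR_MODEL in the environment

    prompt = Prompt(
        query="How are payloads stored?",
        documents=[Document(title="points.md", text="Qdrant stores points ...")],
    )

    for chunk in generator.generate(prompt):
        # ollama.generate(..., stream=True) yields dicts; the incremental
        # text lives under the "response" key.
        print(chunk["response"], end="", flush=True)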
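One wrinkle in the removed CohereGenerator is worth noting: its generator interleaves three kinds of items in one stream, text fragments from text-generation events, citation lists from citation-generation events, and the finish reason string at stream-end, so a consumer that prints every yielded string also prints the finish reason. A hedged consumer sketch, reusing `prompt` from above:

    from rag.llm.cohere_generator import CohereGenerator  # pre-refactor path

    generator = CohereGenerator()  # requires COHERE_API_KEY in the environment
    citations = []

    for item in generator.generate(prompt):
        if isinstance(item, list):
            citations.extend(item)  # citation-generation events yield lists
        else:
            # Both text chunks and the final finish reason (e.g. "COMPLETE")
            # arrive as plain strings; the deleted API offered no way to tell
            # them apart, so callers had to tolerate the trailing token.
            print(item, end="", flush=True)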