summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/vqvae_model.py')
-rw-r--r--text_recognizer/models/vqvae_model.py80
1 files changed, 0 insertions, 80 deletions
diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py
deleted file mode 100644
index 70f6f1f..0000000
--- a/text_recognizer/models/vqvae_model.py
+++ /dev/null
@@ -1,80 +0,0 @@
-"""Defines the VQVAEModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class VQVAEModel(Model):
- """Model for reconstructing images from codebook."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- """Initializes the CharacterModel."""
-
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- self.pad_token = dataset_args["args"]["pad_token"]
- if self._mapper is None:
- self._mapper = EmnistMapper(pad_token=self.pad_token,)
- self.tensor_transform = ToTensor()
- self.softmax = nn.Softmax(dim=0)
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
- """Reconstruction of image.
-
- Args:
- image (Union[np.ndarray, torch.Tensor]): An image containing a character.
-
- Returns:
- Tuple[str, float]: The predicted character and the confidence in the prediction.
-
- """
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
- image_reconstructed, _ = self.forward(image)
-
- return image_reconstructed