diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
commit | 8fdb6435e15703fa5b76df19728d905650ee1aef (patch) | |
tree | be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer/models | |
parent | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff) | |
parent | 6cb08a110620ee09fe9d8a5d008197a801d025df (diff) |
Working cnn transformer.
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/__init__.py | 14 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 8 | ||||
-rw-r--r-- | src/text_recognizer/models/metrics.py | 21 | ||||
-rw-r--r-- | src/text_recognizer/models/transformer_encoder_model.py | 111 | ||||
-rw-r--r-- | src/text_recognizer/models/transformer_model.py (renamed from src/text_recognizer/models/vision_transformer_model.py) | 13 |
5 files changed, 33 insertions, 134 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index 28aa52e..53340f1 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -2,18 +2,16 @@ from .base import Model from .character_model import CharacterModel from .crnn_model import CRNNModel -from .metrics import accuracy, cer, wer -from .transformer_encoder_model import TransformerEncoderModel -from .vision_transformer_model import VisionTransformerModel +from .metrics import accuracy, accuracy_ignore_pad, cer, wer +from .transformer_model import TransformerModel __all__ = [ - "Model", + "accuracy", + "accuracy_ignore_pad", "cer", "CharacterModel", "CRNNModel", - "CNNTransfromerModel", - "accuracy", - "TransformerEncoderModel", - "VisionTransformerModel", + "Model", + "TransformerModel", "wer", ] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index cc44c92..a945b41 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -49,7 +49,7 @@ class Model(ABC): network_args (Optional[Dict]): Arguments for the network. Defaults to None. dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. - criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. + criterion (Optional[Callable]): The criterion to evaluate the performance of the network. Defaults to None. criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. @@ -221,7 +221,7 @@ class Model(ABC): def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" - # If no network arguemnts are given, load pretrained weights if they exist. + # If no network arguments are given, load pretrained weights if they exist. if self._network_args is None: self.load_weights(network_fn) else: @@ -245,7 +245,7 @@ class Model(ABC): self._optimizer = None if self._optimizer and self._lr_scheduler is not None: - if "OneCycleLR" in str(self._lr_scheduler): + if "steps_per_epoch" in self.lr_scheduler_args: self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) # Assume lr scheduler should update at each epoch if not specified. @@ -412,7 +412,7 @@ class Model(ABC): self._optimizer.load_state_dict(checkpoint["optimizer_state"]) if self._lr_scheduler is not None: - # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs + # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs # with OneCycleLR. if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": self._lr_scheduler["lr_scheduler"].load_state_dict( diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 42c3c6e..af9adb5 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -6,7 +6,23 @@ from torch import Tensor from text_recognizer.networks import greedy_decoder -def accuracy(outputs: Tensor, labels: Tensor) -> float: +def accuracy_ignore_pad( + output: Tensor, + target: Tensor, + pad_index: int = 79, + eos_index: int = 81, + seq_len: int = 97, +) -> float: + """Sets all predictions after eos to pad.""" + start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) + end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) + for start, stop in zip(start_indices, end_indices): + output[start + 1 : stop] = pad_index + + return accuracy(output, target) + + +def accuracy(outputs: Tensor, labels: Tensor,) -> float: """Computes the accuracy. Args: @@ -17,10 +33,9 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float: float: The accuracy for the batch. """ - # eos_index = torch.nonzero(labels == eos, as_tuple=False) - # eos_index = eos_index[0].item() if eos_index.nelement() else -1 _, predicted = torch.max(outputs, dim=-1) + acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py deleted file mode 100644 index e35e298..0000000 --- a/src/text_recognizer/models/transformer_encoder_model.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Defines the CNN-Transformer class.""" -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class TransformerEncoderModel(Model): - """A class for only using the encoder part in the sequence modelling.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - # self.init_token = dataset_args["args"]["init_token"] - self.pad_token = dataset_args["args"]["pad_token"] - self.eos_token = dataset_args["args"]["eos_token"] - if network_args is not None: - self.max_len = network_args["max_len"] - else: - self.max_len = 128 - - if self._mapper is None: - self._mapper = EmnistMapper( - # init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) - self.tensor_transform = ToTensor() - - self.softmax = nn.Softmax(dim=2) - - @torch.no_grad() - def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: - logits = self.network(image) - # Convert logits to probabilities. - probs = self.softmax(logits).squeeze(0) - - confidence, pred_tokens = probs.max(1) - pred_tokens = pred_tokens - - eos_index = torch.nonzero( - pred_tokens == self._mapper(self.eos_token), as_tuple=False, - ) - - eos_index = eos_index[0].item() if eos_index.nelement() else -1 - - predicted_characters = "".join( - [self.mapper(x) for x in pred_tokens[:eos_index].tolist()] - ) - - confidence = np.min(confidence.tolist()) - - return predicted_characters, confidence - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - - (predicted_characters, confidence_of_prediction,) = self._generate_sentence( - image - ) - - return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/transformer_model.py index 3d36437..968a047 100644 --- a/src/text_recognizer/models/vision_transformer_model.py +++ b/src/text_recognizer/models/transformer_model.py @@ -13,7 +13,7 @@ from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder -class VisionTransformerModel(Model): +class TransformerModel(Model): """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" def __init__( @@ -50,10 +50,7 @@ class VisionTransformerModel(Model): self.init_token = dataset_args["args"]["init_token"] self.pad_token = dataset_args["args"]["pad_token"] self.eos_token = dataset_args["args"]["eos_token"] - if network_args is not None: - self.max_len = network_args["max_len"] - else: - self.max_len = 120 + self.max_len = 120 if self._mapper is None: self._mapper = EmnistMapper( @@ -67,7 +64,7 @@ class VisionTransformerModel(Model): @torch.no_grad() def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: - src = self.network.preprocess_input(image) + src = self.network.extract_image_features(image) memory = self.network.encoder(src) confidence_of_predictions = [] @@ -75,7 +72,7 @@ class VisionTransformerModel(Model): for _ in range(self.max_len - 1): trg = torch.tensor(trg_indices, device=self.device)[None, :].long() - trg = self.network.preprocess_target(trg) + trg = self.network.target_embedding(trg) logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) # Convert logits to probabilities. @@ -101,7 +98,7 @@ class VisionTransformerModel(Model): self.eval() if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. image = self.tensor_transform(image) # Rescale image between 0 and 1. |