From beeaef529e7c893a3475fe27edc880e283373725 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Sun, 8 Nov 2020 12:41:04 +0100
Subject: Trying to get the CNNTransformer to work, but it is hard.

---
 src/text_recognizer/models/__init__.py        |   7 +-
 src/text_recognizer/models/base.py            |   9 +-
 src/text_recognizer/models/character_model.py |   3 +-
 src/text_recognizer/models/crnn_model.py      | 119 +++++++++++++++++++++
 src/text_recognizer/models/line_ctc_model.py  | 117 --------------------
 src/text_recognizer/models/metrics.py         |   5 +-
 .../models/transformer_encoder_model.py       | 111 +++++++++++++++++++
 .../models/vision_transformer_model.py        |  12 ++-
 8 files changed, 253 insertions(+), 130 deletions(-)
 create mode 100644 src/text_recognizer/models/crnn_model.py
 delete mode 100644 src/text_recognizer/models/line_ctc_model.py
 create mode 100644 src/text_recognizer/models/transformer_encoder_model.py

(limited to 'src/text_recognizer/models')

diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index 0855079..28aa52e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,16 +1,19 @@
 """Model modules."""
 from .base import Model
 from .character_model import CharacterModel
-from .line_ctc_model import LineCTCModel
+from .crnn_model import CRNNModel
 from .metrics import accuracy, cer, wer
+from .transformer_encoder_model import TransformerEncoderModel
 from .vision_transformer_model import VisionTransformerModel
 
 __all__ = [
     "Model",
     "cer",
     "CharacterModel",
+    "CRNNModel",
     "CNNTransfromerModel",
-    "LineCTCModel",
     "accuracy",
+    "TransformerEncoderModel",
+    "VisionTransformerModel",
     "wer",
 ]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cbef787..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -141,11 +141,12 @@ class Model(ABC):
             "transform" in self.dataset_args["args"]
             and self.dataset_args["args"]["transform"] is not None
         ):
-            transform_ = [
-                getattr(transforms_module, t["type"])()
-                for t in self.dataset_args["args"]["transform"]
-            ]
+            transform_ = []
+            for t in self.dataset_args["args"]["transform"]:
+                args = t["args"] or {}
+                transform_.append(getattr(transforms_module, t["type"])(**args))
             self.dataset_args["args"]["transform"] = Compose(transform_)
+
         if (
             "target_transform" in self.dataset_args["args"]
             and self.dataset_args["args"]["target_transform"] is not None
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 3cf6695..f9944f3 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -47,8 +47,9 @@ class CharacterModel(Model):
             swa_args,
             device,
         )
+        self.pad_token = dataset_args["args"]["pad_token"]
         if self._mapper is None:
-            self._mapper = EmnistMapper()
+            self._mapper = EmnistMapper(pad_token=self.pad_token,)
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)
 
diff --git a/src/text_recognizer/models/crnn_model.py b/src/text_recognizer/models/crnn_model.py
new file mode 100644
index 0000000..1e01a83
--- /dev/null
+++ b/src/text_recognizer/models/crnn_model.py
@@ -0,0 +1,119 @@
+"""Defines the CRNNModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CRNNModel(Model):
+    """Model for predicting a sequence of characters from an image of a text line."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+
+        self.pad_token = dataset_args["args"]["pad_token"]
+        if self._mapper is None:
+            self._mapper = EmnistMapper(pad_token=self.pad_token,)
+        self.tensor_transform = ToTensor()
+
+    def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+        """Computes the CTC loss.
+
+        Args:
+            output (Tensor): Model predictions.
+            targets (Tensor): Correct output sequence.
+
+        Returns:
+            Tensor: The CTC loss.
+
+        """
+
+        # Input lengths on the form [T, B]
+        input_lengths = torch.full(
+            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+        )
+
+        # Configure target tensors for ctc loss.
+        targets_ = Tensor([]).to(self.device)
+        target_lengths = []
+        for t in targets:
+            # Remove padding symbol as it acts as the blank symbol.
+            t = t[t < 79]
+            targets_ = torch.cat([targets_, t])
+            target_lengths.append(len(t))
+
+        targets = targets_.type(dtype=torch.long)
+        target_lengths = (
+            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+        )
+
+        return self._criterion(output, targets, input_lengths, target_lengths)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        log_probs = self.forward(image)
+
+        raw_pred, _ = greedy_decoder(
+            predictions=log_probs,
+            character_mapper=self.mapper,
+            blank_label=79,
+            collapse_repeated=True,
+        )
+
+        log_probs, _ = log_probs.max(dim=2)
+
+        predicted_characters = "".join(raw_pred[0])
+        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+        return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
deleted file mode 100644
index cdc2d8b..0000000
--- a/src/text_recognizer/models/line_ctc_model.py
+++ /dev/null
@@ -1,117 +0,0 @@
-"""Defines the LineCTCModel class."""
-from typing import Callable, Dict, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-from text_recognizer.networks import greedy_decoder
-
-
-class LineCTCModel(Model):
-    """Model for predicting a sequence of characters from an image of a text line."""
-
-    def __init__(
-        self,
-        network_fn: Type[nn.Module],
-        dataset: Type[Dataset],
-        network_args: Optional[Dict] = None,
-        dataset_args: Optional[Dict] = None,
-        metrics: Optional[Dict] = None,
-        criterion: Optional[Callable] = None,
-        criterion_args: Optional[Dict] = None,
-        optimizer: Optional[Callable] = None,
-        optimizer_args: Optional[Dict] = None,
-        lr_scheduler: Optional[Callable] = None,
-        lr_scheduler_args: Optional[Dict] = None,
-        swa_args: Optional[Dict] = None,
-        device: Optional[str] = None,
-    ) -> None:
-        super().__init__(
-            network_fn,
-            dataset,
-            network_args,
-            dataset_args,
-            metrics,
-            criterion,
-            criterion_args,
-            optimizer,
-            optimizer_args,
-            lr_scheduler,
-            lr_scheduler_args,
-            swa_args,
-            device,
-        )
-        if self._mapper is None:
-            self._mapper = EmnistMapper()
-        self.tensor_transform = ToTensor()
-
-    def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
-        """Computes the CTC loss.
-
-        Args:
-            output (Tensor): Model predictions.
-            targets (Tensor): Correct output sequence.
-
-        Returns:
-            Tensor: The CTC loss.
-
-        """
-
-        # Input lengths on the form [T, B]
-        input_lengths = torch.full(
-            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
-        )
-
-        # Configure target tensors for ctc loss.
-        targets_ = Tensor([]).to(self.device)
-        target_lengths = []
-        for t in targets:
-            # Remove padding symbol as it acts as the blank symbol.
-            t = t[t < 79]
-            targets_ = torch.cat([targets_, t])
-            target_lengths.append(len(t))
-
-        targets = targets_.type(dtype=torch.long)
-        target_lengths = (
-            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
-        )
-
-        return self._criterion(output, targets, input_lengths, target_lengths)
-
-    @torch.no_grad()
-    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
-        """Predict on a single input."""
-        self.eval()
-
-        if image.dtype == np.uint8:
-            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
-            image = self.tensor_transform(image)
-
-        # Rescale image between 0 and 1.
-        if image.dtype == torch.uint8:
-            # If the image is an unscaled tensor.
-            image = image.type("torch.FloatTensor") / 255
-
-        # Put the image tensor on the device the model weights are on.
-        image = image.to(self.device)
-        log_probs = self.forward(image)
-
-        raw_pred, _ = greedy_decoder(
-            predictions=log_probs,
-            character_mapper=self.mapper,
-            blank_label=79,
-            collapse_repeated=True,
-        )
-
-        log_probs, _ = log_probs.max(dim=2)
-
-        predicted_characters = "".join(raw_pred[0])
-        confidence_of_prediction = torch.exp(-log_probs.sum()).item()
-
-        return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 6a26216..42c3c6e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
         float: The accuracy for the batch.
 
     """
-    _, predicted = torch.max(outputs.data, dim=1)
+    # eos_index = torch.nonzero(labels == eos, as_tuple=False)
+    # eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+    _, predicted = torch.max(outputs, dim=-1)
     acc = (predicted == labels).sum().float() / labels.shape[0]
     acc = acc.item()
     return acc
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
new file mode 100644
index 0000000..e35e298
--- /dev/null
+++ b/src/text_recognizer/models/transformer_encoder_model.py
@@ -0,0 +1,111 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class TransformerEncoderModel(Model):
+    """A class for only using the encoder part in the sequence modelling."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        # self.init_token = dataset_args["args"]["init_token"]
+        self.pad_token = dataset_args["args"]["pad_token"]
+        self.eos_token = dataset_args["args"]["eos_token"]
+        if network_args is not None:
+            self.max_len = network_args["max_len"]
+        else:
+            self.max_len = 128
+
+        if self._mapper is None:
+            self._mapper = EmnistMapper(
+                # init_token=self.init_token,
+                pad_token=self.pad_token,
+                eos_token=self.eos_token,
+            )
+        self.tensor_transform = ToTensor()
+
+        self.softmax = nn.Softmax(dim=2)
+
+    @torch.no_grad()
+    def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+        logits = self.network(image)
+        # Convert logits to probabilities.
+        probs = self.softmax(logits).squeeze(0)
+
+        confidence, pred_tokens = probs.max(1)
+        pred_tokens = pred_tokens
+
+        eos_index = torch.nonzero(
+            pred_tokens == self._mapper(self.eos_token), as_tuple=False,
+        )
+
+        eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+        predicted_characters = "".join(
+            [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
+        )
+
+        confidence = np.min(confidence.tolist())
+
+        return predicted_characters, confidence
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+
+        (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+            image
+        )
+
+        return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
index 20bd4ca..3d36437 100644
--- a/src/text_recognizer/models/vision_transformer_model.py
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -53,7 +53,7 @@ class VisionTransformerModel(Model):
         if network_args is not None:
             self.max_len = network_args["max_len"]
         else:
-            self.max_len = 128
+            self.max_len = 120
 
         if self._mapper is None:
             self._mapper = EmnistMapper(
@@ -73,10 +73,10 @@ class VisionTransformerModel(Model):
         confidence_of_predictions = []
         trg_indices = [self.mapper(self.init_token)]
 
-        for _ in range(self.max_len):
+        for _ in range(self.max_len - 1):
             trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
-            trg, trg_mask = self.network.preprocess_target(trg)
-            logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+            trg = self.network.preprocess_target(trg)
+            logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
 
             # Convert logits to probabilities.
             probs = self.softmax(logits)
@@ -112,6 +112,8 @@ class VisionTransformerModel(Model):
         # Put the image tensor on the device the model weights are on.
         image = image.to(self.device)
 
-        predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+        (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+            image
+        )
 
         return predicted_characters, confidence_of_prediction
-- 
cgit v1.2.3-70-g09d2
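
Note on the CTC plumbing in CRNNModel.criterion above: the padded targets are flattened and the padding index 79 doubles as the CTC blank. A minimal standalone sketch of that preparation, assuming the configured criterion is torch.nn.CTCLoss and that the network emits log-softmax outputs of shape [T, B, C]; the shapes and dummy values here are illustrative and not part of the commit.

    # Sketch: preparing padded targets for torch.nn.CTCLoss the way
    # CRNNModel.criterion does. Assumes 80 classes, with index 79 acting
    # as both the padding symbol and the CTC blank.
    import torch
    from torch import nn

    T, B, C = 20, 2, 80  # time steps, batch size, classes (blank = 79)
    log_probs = torch.randn(T, B, C).log_softmax(dim=2)

    # Padded integer targets of shape [B, S]; 79 pads the short sequences.
    targets = torch.tensor([[1, 2, 3, 79, 79], [4, 5, 79, 79, 79]])

    # Every output column spans the full T time steps.
    input_lengths = torch.full(size=(B,), fill_value=T, dtype=torch.long)

    # Strip the padding, flatten the targets, and record each true length.
    flat_targets, target_lengths = [], []
    for t in targets:
        t = t[t < 79]  # padding doubles as the blank, so drop it
        flat_targets.append(t)
        target_lengths.append(len(t))
    flat_targets = torch.cat(flat_targets).long()
    target_lengths = torch.tensor(target_lengths, dtype=torch.long)

    loss = nn.CTCLoss(blank=79)(log_probs, flat_targets, input_lengths, target_lengths)
    print(loss.item())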
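
For the decoding side, a hypothetical stand-in for the greedy_decoder that crnn_model.py imports: per-time-step argmax, collapse consecutive repeats, drop the blank (79). The real text_recognizer.networks.greedy_decoder takes a character_mapper and returns character strings, so its interface and details may differ from this sketch.

    # Sketch of greedy CTC decoding: argmax each time step, collapse
    # repeats, and remove the blank label. Index sequences only; mapping
    # indices back to characters is left to the caller.
    import torch

    def greedy_ctc_decode(log_probs: torch.Tensor, blank: int = 79) -> list:
        """Decode log-probs of shape [T, B, C] into one index sequence per batch item."""
        best = log_probs.argmax(dim=2)  # [T, B]
        decoded = []
        for b in range(best.shape[1]):
            seq, prev = [], blank
            for idx in best[:, b].tolist():
                # Emit only on a change of symbol, and never emit the blank.
                if idx != prev and idx != blank:
                    seq.append(idx)
                prev = idx
            decoded.append(seq)
        return decoded

    log_probs = torch.randn(20, 1, 80).log_softmax(dim=2)
    print(greedy_ctc_decode(log_probs))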