diff options
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r-- | src/text_recognizer/models/__init__.py | 16 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 66 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/line_ctc_model.py) | 16 | ||||
-rw-r--r-- | src/text_recognizer/models/metrics.py | 5 | ||||
-rw-r--r-- | src/text_recognizer/models/transformer_encoder_model.py | 111 | ||||
-rw-r--r-- | src/text_recognizer/models/vision_transformer_model.py | 119 |
7 files changed, 311 insertions, 26 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index a3cfc15..28aa52e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,7 +1,19 @@ """Model modules.""" from .base import Model from .character_model import CharacterModel -from .line_ctc_model import LineCTCModel +from .crnn_model import CRNNModel from .metrics import accuracy, cer, wer +from .transformer_encoder_model import TransformerEncoderModel +from .vision_transformer_model import VisionTransformerModel -__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"] +__all__ = [ + "Model", + "cer", + "CharacterModel", + "CRNNModel", + "CNNTransfromerModel", + "accuracy", + "TransformerEncoderModel", + "VisionTransformerModel", + "wer", +] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index caf8065..cc44c92 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -6,7 +6,7 @@ import importlib from pathlib import Path import re import shutil -from typing import Callable, Dict, Optional, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from loguru import logger import torch @@ -15,6 +15,7 @@ from torch import Tensor from torch.optim.swa_utils import AveragedModel, SWALR from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary +from torchvision.transforms import Compose from text_recognizer.datasets import EmnistMapper @@ -128,16 +129,42 @@ class Model(ABC): self._configure_criterion() self._configure_optimizers() - # Prints a summary of the network in terminal. - self.summary() - # Set this flag to true to prevent the model from configuring again. self.is_configured = True + def _configure_transforms(self) -> None: + # Load transforms. + transforms_module = importlib.import_module( + "text_recognizer.datasets.transforms" + ) + if ( + "transform" in self.dataset_args["args"] + and self.dataset_args["args"]["transform"] is not None + ): + transform_ = [] + for t in self.dataset_args["args"]["transform"]: + args = t["args"] or {} + transform_.append(getattr(transforms_module, t["type"])(**args)) + self.dataset_args["args"]["transform"] = Compose(transform_) + + if ( + "target_transform" in self.dataset_args["args"] + and self.dataset_args["args"]["target_transform"] is not None + ): + target_transform_ = [ + torch.tensor, + ] + for t in self.dataset_args["args"]["target_transform"]: + args = t["args"] or {} + target_transform_.append(getattr(transforms_module, t["type"])(**args)) + self.dataset_args["args"]["target_transform"] = Compose(target_transform_) + def prepare_data(self) -> None: """Prepare data for training.""" # TODO add downloading. if not self.data_prepared: + self._configure_transforms() + # Load train dataset. train_dataset = self.dataset(train=True, **self.dataset_args["args"]) train_dataset.load_or_generate_data() @@ -327,20 +354,20 @@ class Model(ABC): else: return self.network(x) - def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: - """Compute the loss.""" - return self.criterion(output, targets) - def summary( - self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3 + self, + input_shape: Optional[Union[List, Tuple]] = None, + depth: int = 4, + device: Optional[str] = None, ) -> None: """Prints a summary of the network architecture.""" + device = self.device if device is None else device if input_shape is not None: - summary(self.network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=device) elif self._input_shape is not None: input_shape = (1,) + tuple(self._input_shape) - summary(self.network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=device) else: logger.warning("Could not print summary as input shape is not set.") @@ -356,25 +383,29 @@ class Model(ABC): state["optimizer_state"] = self._optimizer.state_dict() if self._lr_scheduler is not None: - state["scheduler_state"] = self._lr_scheduler.state_dict() + state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() + state["scheduler_interval"] = self._lr_scheduler["interval"] if self._swa_network is not None: state["swa_network"] = self._swa_network.state_dict() return state - def load_from_checkpoint(self, checkpoint_path: Path) -> None: + def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None: """Load a previously saved checkpoint. Args: checkpoint_path (Path): Path to the experiment with the checkpoint. """ + checkpoint_path = Path(checkpoint_path) + self.prepare_data() + self.configure_model() logger.debug("Loading checkpoint...") if not checkpoint_path.exists(): logger.debug("File does not exist {str(checkpoint_path)}") - checkpoint = torch.load(str(checkpoint_path)) + checkpoint = torch.load(str(checkpoint_path), map_location=self.device) self._network.load_state_dict(checkpoint["model_state"]) if self._optimizer is not None: @@ -383,8 +414,11 @@ class Model(ABC): if self._lr_scheduler is not None: # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs # with OneCycleLR. - if self._lr_scheduler.__class__.__name__ != "OneCycleLR": - self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": + self._lr_scheduler["lr_scheduler"].load_state_dict( + checkpoint["scheduler_state"] + ) + self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] if self._swa_network is not None: self._swa_network.load_state_dict(checkpoint["swa_network"]) diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 50e94a2..f9944f3 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -47,8 +47,9 @@ class CharacterModel(Model): swa_args, device, ) + self.pad_token = dataset_args["args"]["pad_token"] if self._mapper is None: - self._mapper = EmnistMapper() + self._mapper = EmnistMapper(pad_token=self.pad_token,) self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) @@ -65,6 +66,7 @@ class CharacterModel(Model): Tuple[str, float]: The predicted character and the confidence in the prediction. """ + self.eval() if image.dtype == np.uint8: # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py index 16eaed3..1e01a83 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/crnn_model.py @@ -1,4 +1,4 @@ -"""Defines the LineCTCModel class.""" +"""Defines the CRNNModel class.""" from typing import Callable, Dict, Optional, Tuple, Type, Union import numpy as np @@ -13,7 +13,7 @@ from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder -class LineCTCModel(Model): +class CRNNModel(Model): """Model for predicting a sequence of characters from an image of a text line.""" def __init__( @@ -47,11 +47,13 @@ class LineCTCModel(Model): swa_args, device, ) + + self.pad_token = dataset_args["args"]["pad_token"] if self._mapper is None: - self._mapper = EmnistMapper() + self._mapper = EmnistMapper(pad_token=self.pad_token,) self.tensor_transform = ToTensor() - def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: """Computes the CTC loss. Args: @@ -82,11 +84,13 @@ class LineCTCModel(Model): torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) ) - return self.criterion(output, targets, input_lengths, target_lengths) + return self._criterion(output, targets, input_lengths, target_lengths) @torch.no_grad() def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: """Predict on a single input.""" + self.eval() + if image.dtype == np.uint8: # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. image = self.tensor_transform(image) @@ -110,6 +114,6 @@ class LineCTCModel(Model): log_probs, _ = log_probs.max(dim=2) predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = torch.exp(log_probs.sum()).item() + confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 6a26216..42c3c6e 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float: float: The accuracy for the batch. """ - _, predicted = torch.max(outputs.data, dim=1) + # eos_index = torch.nonzero(labels == eos, as_tuple=False) + # eos_index = eos_index[0].item() if eos_index.nelement() else -1 + + _, predicted = torch.max(outputs, dim=-1) acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py new file mode 100644 index 0000000..e35e298 --- /dev/null +++ b/src/text_recognizer/models/transformer_encoder_model.py @@ -0,0 +1,111 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class TransformerEncoderModel(Model): + """A class for only using the encoder part in the sequence modelling.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + # self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + if network_args is not None: + self.max_len = network_args["max_len"] + else: + self.max_len = 128 + + if self._mapper is None: + self._mapper = EmnistMapper( + # init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + self.tensor_transform = ToTensor() + + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + logits = self.network(image) + # Convert logits to probabilities. + probs = self.softmax(logits).squeeze(0) + + confidence, pred_tokens = probs.max(1) + pred_tokens = pred_tokens + + eos_index = torch.nonzero( + pred_tokens == self._mapper(self.eos_token), as_tuple=False, + ) + + eos_index = eos_index[0].item() if eos_index.nelement() else -1 + + predicted_characters = "".join( + [self.mapper(x) for x in pred_tokens[:eos_index].tolist()] + ) + + confidence = np.min(confidence.tolist()) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py new file mode 100644 index 0000000..3d36437 --- /dev/null +++ b/src/text_recognizer/models/vision_transformer_model.py @@ -0,0 +1,119 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class VisionTransformerModel(Model): + """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + if network_args is not None: + self.max_len = network_args["max_len"] + else: + self.max_len = 120 + + if self._mapper is None: + self._mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + self.tensor_transform = ToTensor() + + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + src = self.network.preprocess_input(image) + memory = self.network.encoder(src) + + confidence_of_predictions = [] + trg_indices = [self.mapper(self.init_token)] + + for _ in range(self.max_len - 1): + trg = torch.tensor(trg_indices, device=self.device)[None, :].long() + trg = self.network.preprocess_target(trg) + logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) + + # Convert logits to probabilities. + probs = self.softmax(logits) + + pred_token = probs.argmax(2)[:, -1].item() + confidence = probs.max(2).values[:, -1].item() + + trg_indices.append(pred_token) + confidence_of_predictions.append(confidence) + + if pred_token == self.mapper(self.eos_token): + break + + confidence = np.min(confidence_of_predictions) + predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) + + return predicted_characters, confidence_of_prediction |