diff options
Diffstat (limited to 'src/text_recognizer/models')
| -rw-r--r-- | src/text_recognizer/models/__init__.py | 7 | ||||
| -rw-r--r-- | src/text_recognizer/models/base.py | 9 | ||||
| -rw-r--r-- | src/text_recognizer/models/character_model.py | 3 | ||||
| -rw-r--r-- | src/text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/line_ctc_model.py) | 10 | ||||
| -rw-r--r-- | src/text_recognizer/models/metrics.py | 5 | ||||
| -rw-r--r-- | src/text_recognizer/models/transformer_encoder_model.py | 111 | ||||
| -rw-r--r-- | src/text_recognizer/models/vision_transformer_model.py | 12 | 
7 files changed, 140 insertions, 17 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index 0855079..28aa52e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,16 +1,19 @@  """Model modules."""  from .base import Model  from .character_model import CharacterModel -from .line_ctc_model import LineCTCModel +from .crnn_model import CRNNModel  from .metrics import accuracy, cer, wer +from .transformer_encoder_model import TransformerEncoderModel  from .vision_transformer_model import VisionTransformerModel  __all__ = [      "Model",      "cer",      "CharacterModel", +    "CRNNModel",      "CNNTransfromerModel", -    "LineCTCModel",      "accuracy", +    "TransformerEncoderModel", +    "VisionTransformerModel",      "wer",  ] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index cbef787..cc44c92 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -141,11 +141,12 @@ class Model(ABC):              "transform" in self.dataset_args["args"]              and self.dataset_args["args"]["transform"] is not None          ): -            transform_ = [ -                getattr(transforms_module, t["type"])() -                for t in self.dataset_args["args"]["transform"] -            ] +            transform_ = [] +            for t in self.dataset_args["args"]["transform"]: +                args = t["args"] or {} +                transform_.append(getattr(transforms_module, t["type"])(**args))              self.dataset_args["args"]["transform"] = Compose(transform_) +          if (              "target_transform" in self.dataset_args["args"]              and self.dataset_args["args"]["target_transform"] is not None diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 3cf6695..f9944f3 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -47,8 +47,9 @@ class CharacterModel(Model):              swa_args,              device,          ) +        self.pad_token = dataset_args["args"]["pad_token"]          if self._mapper is None: -            self._mapper = EmnistMapper() +            self._mapper = EmnistMapper(pad_token=self.pad_token,)          self.tensor_transform = ToTensor()          self.softmax = nn.Softmax(dim=0) diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py index cdc2d8b..1e01a83 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/crnn_model.py @@ -1,4 +1,4 @@ -"""Defines the LineCTCModel class.""" +"""Defines the CRNNModel class."""  from typing import Callable, Dict, Optional, Tuple, Type, Union  import numpy as np @@ -13,7 +13,7 @@ from text_recognizer.models.base import Model  from text_recognizer.networks import greedy_decoder -class LineCTCModel(Model): +class CRNNModel(Model):      """Model for predicting a sequence of characters from an image of a text line."""      def __init__( @@ -47,8 +47,10 @@ class LineCTCModel(Model):              swa_args,              device,          ) + +        self.pad_token = dataset_args["args"]["pad_token"]          if self._mapper is None: -            self._mapper = EmnistMapper() +            self._mapper = EmnistMapper(pad_token=self.pad_token,)          self.tensor_transform = ToTensor()      def criterion(self, output: Tensor, targets: Tensor) -> Tensor: @@ -112,6 +114,6 @@ class LineCTCModel(Model):          log_probs, _ = log_probs.max(dim=2)          predicted_characters = "".join(raw_pred[0]) -        confidence_of_prediction = torch.exp(-log_probs.sum()).item() +        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()          return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 6a26216..42c3c6e 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:          float: The accuracy for the batch.      """ -    _, predicted = torch.max(outputs.data, dim=1) +    # eos_index = torch.nonzero(labels == eos, as_tuple=False) +    # eos_index = eos_index[0].item() if eos_index.nelement() else -1 + +    _, predicted = torch.max(outputs, dim=-1)      acc = (predicted == labels).sum().float() / labels.shape[0]      acc = acc.item()      return acc diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py new file mode 100644 index 0000000..e35e298 --- /dev/null +++ b/src/text_recognizer/models/transformer_encoder_model.py @@ -0,0 +1,111 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class TransformerEncoderModel(Model): +    """A class for only using the encoder part in the sequence modelling.""" + +    def __init__( +        self, +        network_fn: Type[nn.Module], +        dataset: Type[Dataset], +        network_args: Optional[Dict] = None, +        dataset_args: Optional[Dict] = None, +        metrics: Optional[Dict] = None, +        criterion: Optional[Callable] = None, +        criterion_args: Optional[Dict] = None, +        optimizer: Optional[Callable] = None, +        optimizer_args: Optional[Dict] = None, +        lr_scheduler: Optional[Callable] = None, +        lr_scheduler_args: Optional[Dict] = None, +        swa_args: Optional[Dict] = None, +        device: Optional[str] = None, +    ) -> None: +        super().__init__( +            network_fn, +            dataset, +            network_args, +            dataset_args, +            metrics, +            criterion, +            criterion_args, +            optimizer, +            optimizer_args, +            lr_scheduler, +            lr_scheduler_args, +            swa_args, +            device, +        ) +        # self.init_token = dataset_args["args"]["init_token"] +        self.pad_token = dataset_args["args"]["pad_token"] +        self.eos_token = dataset_args["args"]["eos_token"] +        if network_args is not None: +            self.max_len = network_args["max_len"] +        else: +            self.max_len = 128 + +        if self._mapper is None: +            self._mapper = EmnistMapper( +                # init_token=self.init_token, +                pad_token=self.pad_token, +                eos_token=self.eos_token, +            ) +        self.tensor_transform = ToTensor() + +        self.softmax = nn.Softmax(dim=2) + +    @torch.no_grad() +    def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: +        logits = self.network(image) +        # Convert logits to probabilities. +        probs = self.softmax(logits).squeeze(0) + +        confidence, pred_tokens = probs.max(1) +        pred_tokens = pred_tokens + +        eos_index = torch.nonzero( +            pred_tokens == self._mapper(self.eos_token), as_tuple=False, +        ) + +        eos_index = eos_index[0].item() if eos_index.nelement() else -1 + +        predicted_characters = "".join( +            [self.mapper(x) for x in pred_tokens[:eos_index].tolist()] +        ) + +        confidence = np.min(confidence.tolist()) + +        return predicted_characters, confidence + +    @torch.no_grad() +    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: +        """Predict on a single input.""" +        self.eval() + +        if image.dtype == np.uint8: +            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. +            image = self.tensor_transform(image) + +        # Rescale image between 0 and 1. +        if image.dtype == torch.uint8: +            # If the image is an unscaled tensor. +            image = image.type("torch.FloatTensor") / 255 + +        # Put the image tensor on the device the model weights are on. +        image = image.to(self.device) + +        (predicted_characters, confidence_of_prediction,) = self._generate_sentence( +            image +        ) + +        return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py index 20bd4ca..3d36437 100644 --- a/src/text_recognizer/models/vision_transformer_model.py +++ b/src/text_recognizer/models/vision_transformer_model.py @@ -53,7 +53,7 @@ class VisionTransformerModel(Model):          if network_args is not None:              self.max_len = network_args["max_len"]          else: -            self.max_len = 128 +            self.max_len = 120          if self._mapper is None:              self._mapper = EmnistMapper( @@ -73,10 +73,10 @@ class VisionTransformerModel(Model):          confidence_of_predictions = []          trg_indices = [self.mapper(self.init_token)] -        for _ in range(self.max_len): +        for _ in range(self.max_len - 1):              trg = torch.tensor(trg_indices, device=self.device)[None, :].long() -            trg, trg_mask = self.network.preprocess_target(trg) -            logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask) +            trg = self.network.preprocess_target(trg) +            logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)              # Convert logits to probabilities.              probs = self.softmax(logits) @@ -112,6 +112,8 @@ class VisionTransformerModel(Model):          # Put the image tensor on the device the model weights are on.          image = image.to(self.device) -        predicted_characters, confidence_of_prediction = self._generate_sentence(image) +        (predicted_characters, confidence_of_prediction,) = self._generate_sentence( +            image +        )          return predicted_characters, confidence_of_prediction  |