diff options
Diffstat (limited to 'src/text_recognizer')
21 files changed, 269 insertions, 406 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index 8deac7f..1105f23 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -3,7 +3,8 @@ import numpy as np  from PIL import Image  import torch  from torch import Tensor -from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor +import torch.nn.functional as F +from torchvision.transforms import Compose, ToPILImage, ToTensor  from text_recognizer.datasets.util import EmnistMapper @@ -16,6 +17,18 @@ class Transpose:          return np.array(image).swapaxes(0, 1) +class Resize: +    """Resizes a tensor to a specified width.""" + +    def __init__(self, width: int = 952) -> None: +        # The default is 952 because of the IAM dataset. +        self.width = width + +    def __call__(self, image: Tensor) -> Tensor: +        """Resize tensor in the last dimension.""" +        return F.interpolate(image, size=self.width, mode="nearest") + +  class AddTokens:      """Adds start of sequence and end of sequence tokens to target tensor.""" diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index 28aa52e..53340f1 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -2,18 +2,16 @@  from .base import Model  from .character_model import CharacterModel  from .crnn_model import CRNNModel -from .metrics import accuracy, cer, wer -from .transformer_encoder_model import TransformerEncoderModel -from .vision_transformer_model import VisionTransformerModel +from .metrics import accuracy, accuracy_ignore_pad, cer, wer +from .transformer_model import TransformerModel  __all__ = [ -    "Model", +    "accuracy", +    "accuracy_ignore_pad",      "cer",      "CharacterModel",      "CRNNModel", -    "CNNTransfromerModel", -    "accuracy", -    "TransformerEncoderModel", -    "VisionTransformerModel", +    "Model", +    "TransformerModel",      "wer",  ] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index cc44c92..a945b41 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -49,7 +49,7 @@ class Model(ABC):              network_args (Optional[Dict]): Arguments for the network. Defaults to None.              dataset_args (Optional[Dict]):  Arguments for the dataset.              metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. -            criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. +            criterion (Optional[Callable]): The criterion to evaluate the performance of the network.                  Defaults to None.              criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.              optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. @@ -221,7 +221,7 @@ class Model(ABC):      def _configure_network(self, network_fn: Type[nn.Module]) -> None:          """Loads the network.""" -        # If no network arguemnts are given, load pretrained weights if they exist. +        # If no network arguments are given, load pretrained weights if they exist.          if self._network_args is None:              self.load_weights(network_fn)          else: @@ -245,7 +245,7 @@ class Model(ABC):              self._optimizer = None          if self._optimizer and self._lr_scheduler is not None: -            if "OneCycleLR" in str(self._lr_scheduler): +            if "steps_per_epoch" in self.lr_scheduler_args:                  self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())              # Assume lr scheduler should update at each epoch if not specified. @@ -412,7 +412,7 @@ class Model(ABC):              self._optimizer.load_state_dict(checkpoint["optimizer_state"])          if self._lr_scheduler is not None: -            # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs +            # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs              # with OneCycleLR.              if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":                  self._lr_scheduler["lr_scheduler"].load_state_dict( diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 42c3c6e..af9adb5 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -6,7 +6,23 @@ from torch import Tensor  from text_recognizer.networks import greedy_decoder -def accuracy(outputs: Tensor, labels: Tensor) -> float: +def accuracy_ignore_pad( +    output: Tensor, +    target: Tensor, +    pad_index: int = 79, +    eos_index: int = 81, +    seq_len: int = 97, +) -> float: +    """Sets all predictions after eos to pad.""" +    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) +    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) +    for start, stop in zip(start_indices, end_indices): +        output[start + 1 : stop] = pad_index + +    return accuracy(output, target) + + +def accuracy(outputs: Tensor, labels: Tensor,) -> float:      """Computes the accuracy.      Args: @@ -17,10 +33,9 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:          float: The accuracy for the batch.      """ -    # eos_index = torch.nonzero(labels == eos, as_tuple=False) -    # eos_index = eos_index[0].item() if eos_index.nelement() else -1      _, predicted = torch.max(outputs, dim=-1) +      acc = (predicted == labels).sum().float() / labels.shape[0]      acc = acc.item()      return acc diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py deleted file mode 100644 index e35e298..0000000 --- a/src/text_recognizer/models/transformer_encoder_model.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Defines the CNN-Transformer class.""" -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class TransformerEncoderModel(Model): -    """A class for only using the encoder part in the sequence modelling.""" - -    def __init__( -        self, -        network_fn: Type[nn.Module], -        dataset: Type[Dataset], -        network_args: Optional[Dict] = None, -        dataset_args: Optional[Dict] = None, -        metrics: Optional[Dict] = None, -        criterion: Optional[Callable] = None, -        criterion_args: Optional[Dict] = None, -        optimizer: Optional[Callable] = None, -        optimizer_args: Optional[Dict] = None, -        lr_scheduler: Optional[Callable] = None, -        lr_scheduler_args: Optional[Dict] = None, -        swa_args: Optional[Dict] = None, -        device: Optional[str] = None, -    ) -> None: -        super().__init__( -            network_fn, -            dataset, -            network_args, -            dataset_args, -            metrics, -            criterion, -            criterion_args, -            optimizer, -            optimizer_args, -            lr_scheduler, -            lr_scheduler_args, -            swa_args, -            device, -        ) -        # self.init_token = dataset_args["args"]["init_token"] -        self.pad_token = dataset_args["args"]["pad_token"] -        self.eos_token = dataset_args["args"]["eos_token"] -        if network_args is not None: -            self.max_len = network_args["max_len"] -        else: -            self.max_len = 128 - -        if self._mapper is None: -            self._mapper = EmnistMapper( -                # init_token=self.init_token, -                pad_token=self.pad_token, -                eos_token=self.eos_token, -            ) -        self.tensor_transform = ToTensor() - -        self.softmax = nn.Softmax(dim=2) - -    @torch.no_grad() -    def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: -        logits = self.network(image) -        # Convert logits to probabilities. -        probs = self.softmax(logits).squeeze(0) - -        confidence, pred_tokens = probs.max(1) -        pred_tokens = pred_tokens - -        eos_index = torch.nonzero( -            pred_tokens == self._mapper(self.eos_token), as_tuple=False, -        ) - -        eos_index = eos_index[0].item() if eos_index.nelement() else -1 - -        predicted_characters = "".join( -            [self.mapper(x) for x in pred_tokens[:eos_index].tolist()] -        ) - -        confidence = np.min(confidence.tolist()) - -        return predicted_characters, confidence - -    @torch.no_grad() -    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: -        """Predict on a single input.""" -        self.eval() - -        if image.dtype == np.uint8: -            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. -            image = self.tensor_transform(image) - -        # Rescale image between 0 and 1. -        if image.dtype == torch.uint8: -            # If the image is an unscaled tensor. -            image = image.type("torch.FloatTensor") / 255 - -        # Put the image tensor on the device the model weights are on. -        image = image.to(self.device) - -        (predicted_characters, confidence_of_prediction,) = self._generate_sentence( -            image -        ) - -        return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/transformer_model.py index 3d36437..968a047 100644 --- a/src/text_recognizer/models/vision_transformer_model.py +++ b/src/text_recognizer/models/transformer_model.py @@ -13,7 +13,7 @@ from text_recognizer.models.base import Model  from text_recognizer.networks import greedy_decoder -class VisionTransformerModel(Model): +class TransformerModel(Model):      """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""      def __init__( @@ -50,10 +50,7 @@ class VisionTransformerModel(Model):          self.init_token = dataset_args["args"]["init_token"]          self.pad_token = dataset_args["args"]["pad_token"]          self.eos_token = dataset_args["args"]["eos_token"] -        if network_args is not None: -            self.max_len = network_args["max_len"] -        else: -            self.max_len = 120 +        self.max_len = 120          if self._mapper is None:              self._mapper = EmnistMapper( @@ -67,7 +64,7 @@ class VisionTransformerModel(Model):      @torch.no_grad()      def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: -        src = self.network.preprocess_input(image) +        src = self.network.extract_image_features(image)          memory = self.network.encoder(src)          confidence_of_predictions = [] @@ -75,7 +72,7 @@ class VisionTransformerModel(Model):          for _ in range(self.max_len - 1):              trg = torch.tensor(trg_indices, device=self.device)[None, :].long() -            trg = self.network.preprocess_target(trg) +            trg = self.network.target_embedding(trg)              logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)              # Convert logits to probabilities. @@ -101,7 +98,7 @@ class VisionTransformerModel(Model):          self.eval()          if image.dtype == np.uint8: -            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. +            # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].              image = self.tensor_transform(image)          # Rescale image between 0 and 1. diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 6d88768..2cc1137 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,25 +1,20 @@  """Network modules."""  from .cnn_transformer import CNNTransformer -from .cnn_transformer_encoder import CNNTransformerEncoder  from .crnn import ConvolutionalRecurrentNetwork  from .ctc import greedy_decoder  from .densenet import DenseNet  from .lenet import LeNet -from .loss import EmbeddingLoss  from .mlp import MLP  from .residual_network import ResidualNetwork, ResidualNetworkEncoder  from .sparse_mlp import SparseMLP  from .transformer import Transformer  from .util import sliding_window -from .vision_transformer import VisionTransformer  from .wide_resnet import WideResidualNetwork  __all__ = [      "CNNTransformer", -    "CNNTransformerEncoder",      "ConvolutionalRecurrentNetwork",      "DenseNet", -    "EmbeddingLoss",      "greedy_decoder",      "MLP",      "LeNet", @@ -28,6 +23,5 @@ __all__ = [      "sliding_window",      "Transformer",      "SparseMLP", -    "VisionTransformer",      "WideResidualNetwork",  ] diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 3da2c9f..16c7a41 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -1,4 +1,4 @@ -"""A DETR style transfomers but for text recognition.""" +"""A CNN-Transformer for image to text recognition."""  from typing import Dict, Optional, Tuple  from einops import rearrange @@ -11,7 +11,7 @@ from text_recognizer.networks.util import configure_backbone  class CNNTransformer(nn.Module): -    """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR.""" +    """CNN+Transfomer for image to sequence prediction."""      def __init__(          self, @@ -25,22 +25,14 @@ class CNNTransformer(nn.Module):          dropout_rate: float,          trg_pad_index: int,          backbone: str, -        out_channels: int, -        max_len: int,          backbone_args: Optional[Dict] = None,          activation: str = "gelu",      ) -> None:          super().__init__()          self.trg_pad_index = trg_pad_index -          self.backbone = configure_backbone(backbone, backbone_args)          self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - -        # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1) -          self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate) -        self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) -        self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))          self.adaptive_pool = (              nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None @@ -78,8 +70,12 @@ class CNNTransformer(nn.Module):              self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)          ) -    def preprocess_input(self, src: Tensor) -> Tensor: -        """Encodes src with a backbone network and a positional encoding. +    def extract_image_features(self, src: Tensor) -> Tensor: +        """Extracts image features with a backbone neural network. + +        It seem like the winning idea was to swap channels and width dimension and collapse +        the height dimension. The transformer is learning like a baby with this implementation!!! :D +        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D          Args:              src (Tensor): Input tensor. @@ -88,29 +84,19 @@ class CNNTransformer(nn.Module):              Tensor: A input src to the transformer.          """ -        # If batch dimenstion is missing, it needs to be added. +        # If batch dimension is missing, it needs to be added.          if len(src.shape) < 4:              src = src[(None,) * (4 - len(src.shape))]          src = self.backbone(src) -        # src = self.conv(src) +        src = rearrange(src, "b c h w -> b w c h")          if self.adaptive_pool is not None:              src = self.adaptive_pool(src) -        H, W = src.shape[-2:] -        src = rearrange(src, "b t h w -> b t (h w)") - -        # construct positional encodings -        pos = torch.cat( -            [ -                self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), -                self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), -            ], -            dim=-1, -        ).unsqueeze(0) -        pos = rearrange(pos, "b h w l -> b l (h w)") -        src = pos + 0.1 * src +        src = src.squeeze(3) +        src = self.position_encoding(src) +          return src -    def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: +    def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:          """Encodes target tensor with embedding and postion.          Args: @@ -126,9 +112,9 @@ class CNNTransformer(nn.Module):      def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:          """Forward pass with CNN transfomer.""" -        h = self.preprocess_input(x) +        h = self.extract_image_features(x)          trg_mask = self._create_trg_mask(trg) -        trg = self.preprocess_target(trg) +        trg = self.target_embedding(trg)          out = self.transformer(h, trg, trg_mask=trg_mask)          logits = self.head(out) diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py deleted file mode 100644 index 93626bf..0000000 --- a/src/text_recognizer/networks/cnn_transformer_encoder.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Network with a CNN backend and a transformer encoder head.""" -from typing import Dict - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding -from text_recognizer.networks.util import configure_backbone - - -class CNNTransformerEncoder(nn.Module): -    """A CNN backbone with Transformer Encoder frontend for sequence prediction.""" - -    def __init__( -        self, -        backbone: str, -        backbone_args: Dict, -        mlp_dim: int, -        d_model: int, -        nhead: int = 8, -        dropout_rate: float = 0.1, -        activation: str = "relu", -        num_layers: int = 6, -        num_classes: int = 80, -        num_channels: int = 256, -        max_len: int = 97, -    ) -> None: -        super().__init__() -        self.d_model = d_model -        self.nhead = nhead -        self.dropout_rate = dropout_rate -        self.activation = activation -        self.num_layers = num_layers - -        self.backbone = configure_backbone(backbone, backbone_args) -        self.position_encoding = PositionalEncoding(d_model, dropout_rate) -        self.encoder = self._configure_encoder() - -        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1) - -        self.mlp = nn.Linear(mlp_dim, d_model) - -        self.head = nn.Linear(d_model, num_classes) - -    def _configure_encoder(self) -> nn.TransformerEncoder: -        encoder_layer = nn.TransformerEncoderLayer( -            d_model=self.d_model, -            nhead=self.nhead, -            dropout=self.dropout_rate, -            activation=self.activation, -        ) -        norm = nn.LayerNorm(self.d_model) -        return nn.TransformerEncoder( -            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm -        ) - -    def forward(self, x: Tensor, targets: Tensor = None) -> Tensor: -        """Forward pass through the network.""" -        if len(x.shape) < 4: -            x = x[(None,) * (4 - len(x.shape))] - -        x = self.conv(self.backbone(x)) -        x = rearrange(x, "b c h w -> b c (h w)") -        x = self.mlp(x) -        x = self.position_encoding(x) -        x = rearrange(x, "b c h-> c b h") -        x = self.encoder(x) -        x = rearrange(x, "c b h-> b c h") -        logits = self.head(x) - -        return logits diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py new file mode 100644 index 0000000..b489264 --- /dev/null +++ b/src/text_recognizer/networks/loss/__init__.py @@ -0,0 +1,2 @@ +"""Loss module.""" +from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss/loss.py index cf9fa0d..cf9fa0d 100644 --- a/src/text_recognizer/networks/loss.py +++ b/src/text_recognizer/networks/loss/loss.py diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py new file mode 100644 index 0000000..7f8c49b --- /dev/null +++ b/src/text_recognizer/networks/neural_machine_reader.py @@ -0,0 +1,201 @@ +"""Sequence to sequence network with RNN cells.""" +# from typing import Dict, Optional, Tuple + +# from einops import rearrange +# from einops.layers.torch import Rearrange +# import torch +# from torch import nn +# from torch import Tensor + +# from text_recognizer.networks.util import configure_backbone + + +# class Encoder(nn.Module): +#     def __init__( +#         self, +#         embedding_dim: int, +#         encoder_dim: int, +#         decoder_dim: int, +#         dropout_rate: float = 0.1, +#     ) -> None: +#         super().__init__() +#         self.rnn = nn.GRU( +#             input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True +#         ) +#         self.fc = nn.Sequential( +#             nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh() +#         ) +#         self.dropout = nn.Dropout(p=dropout_rate) + +#     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: +#         """Encodes a sequence of tensors with a bidirectional GRU. + +#         Args: +#             x (Tensor): A input sequence. + +#         Shape: +#             - x: :math:`(T, N, E)`. +#             - output[0]: :math:`(T, N, 2 * E)`. +#             - output[1]: :math:`(T, N, D)`. + +#             where T is the sequence length, N is the batch size, E is the +#             embedding/encoder dimension, and D is the decoder dimension. + +#         Returns: +#             Tuple[Tensor, Tensor]: The encoder output and the hidden state of the +#                 encoder. + +#         """ + +#         output, hidden = self.rnn(x) + +#         # Get the hidden state from the forward and backward rnn. +#         hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1) + +#         # Apply fully connected layer and tanh activation. +#         hidden_state = self.fc(hidden_state) + +#         return output, hidden_state + + +# class Attention(nn.Module): +#     def __init__(self, encoder_dim: int, decoder_dim: int) -> None: +#         super().__init__() +#         self.atten = nn.Linear( +#             in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim +#         ) +#         self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False) + +#     def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: +#         """Short summary. + +#         Args: +#             hidden_state (Tensor): Description of parameter `h`. +#             encoder_outputs (Tensor): Description of parameter `enc_out`. + +#         Shape: +#             - x: :math:`(T, N, E)`. +#             - output[0]: :math:`(T, N, 2 * E)`. +#             - output[1]: :math:`(T, N, D)`. + +#             where T is the sequence length, N is the batch size, E is the +#             embedding/encoder dimension, and D is the decoder dimension. + +#         Returns: +#             Tensor: Description of returned object. + +#         """ +#         t, b = enc_out.shape[:2] +#         # repeat decoder hidden state src_len times +#         hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1) + +#         encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2") + +#         # Calculate the energy between the decoders previous hidden state and the +#         # encoders hidden states. +#         energy = torch.tanh( +#             self.attn(torch.cat((hidden_state, encoder_outputs), dim=2)) +#         ) + +#         attention = self.value(energy).squeeze(2) + +#         # Apply softmax on the attention to squeeze it between 0 and 1. +#         attention = F.softmax(attention, dim=1) + +#         return attention + + +# class Decoder(nn.Module): +#     def __init__( +#         self, +#         embedding_dim: int, +#         encoder_dim: int, +#         decoder_dim: int, +#         output_dim: int, +#         dropout_rate: float = 0.1, +#     ) -> None: +#         super().__init__() +#         self.output_dim = output_dim +#         self.embedding = nn.Embedding(output_dim, embedding_dim) +#         self.attention = Attention(encoder_dim, decoder_dim) +#         self.rnn = nn.GRU( +#             input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim +#         ) + +#         self.head = nn.Linear( +#             in_features=2 * encoder_dim + embedding_dim + decoder_dim, +#             out_features=output_dim, +#         ) +#         self.dropout = nn.Dropout(p=dropout_rate) + +#     def forward( +#         self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor +#     ) -> Tensor: +#         # input = [batch size] +#         # hidden = [batch size, dec hid dim] +#         # encoder_outputs = [src len, batch size, enc hid dim * 2] +#         trg = trg.unsqueeze(0) +#         trg_embedded = self.dropout(self.embedding(trg)) + +#         a = self.attention(hidden_state, encoder_outputs) + +#         weighted = torch.bmm(a, encoder_outputs) + +#         # Permutate the tensor. +#         weighted = rearrange(weighted, "b a e2 -> a b e2") + +#         rnn_input = torch.cat((trg_embedded, weighted), dim=2) + +#         output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) + +#         # seq len, n layers and n directions will always be 1 in this decoder, therefore: +#         # output = [1, batch size, dec hid dim] +#         # hidden = [1, batch size, dec hid dim] +#         # this also means that output == hidden +#         assert (output == hidden).all() + +#         trg_embedded = trg_embedded.squeeze(0) +#         output = output.squeeze(0) +#         weighted = weighted.squeeze(0) + +#         logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim=1)) + +#         # prediction = [batch size, output dim] + +#         return logits, hidden.squeeze(0) + + +# class NeuralMachineReader(nn.Module): +#     def __init__( +#         self, +#         embedding_dim: int, +#         encoder_dim: int, +#         decoder_dim: int, +#         output_dim: int, +#         backbone: Optional[str] = None, +#         backbone_args: Optional[Dict] = None, +#         adaptive_pool_dim: Tuple = (None, 1), +#         dropout_rate: float = 0.1, +#         teacher_forcing_ratio: float = 0.5, +#     ) -> None: +#         super().__init__() + +#         self.backbone = configure_backbone(backbone, backbone_args) +#         self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim)) + +#         self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate) +#         self.decoder = Decoder( +#             embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate +#         ) +#         self.teacher_forcing_ratio = teacher_forcing_ratio + +#     def extract_image_features(self, x: Tensor) -> Tensor: +#         x = self.backbone(x) +#         x = rearrange(x, "b c h w -> b w c h") +#         x = self.adaptive_pool(x) +#         x = x.squeeze(3) + +#     def forward(self, x: Tensor, trg: Tensor) -> Tensor: +#         # x = [batch size, height, width] +#         # trg = [trg len, batch size] +#         z = self.extract_image_features(x) diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py index b031128..e9d216f 100644 --- a/src/text_recognizer/networks/stn.py +++ b/src/text_recognizer/networks/stn.py @@ -13,7 +13,7 @@ class SpatialTransformerNetwork(nn.Module):      Network that learns how to perform spatial transformations on the input image in order to enhance the      geometric invariance of the model. -    # TODO: add arguements to make it more general. +    # TODO: add arguments to make it more general.      """ diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py index b31e640..e2d7955 100644 --- a/src/text_recognizer/networks/util.py +++ b/src/text_recognizer/networks/util.py @@ -24,7 +24,7 @@ def sliding_window(      """      unfold = nn.Unfold(kernel_size=patch_size, stride=stride) -    # Preform the slidning window, unsqueeze as the channel dimesion is lost. +    # Preform the sliding window, unsqueeze as the channel dimesion is lost.      c = images.shape[1]      patches = unfold(images)      patches = rearrange( diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py deleted file mode 100644 index f227954..0000000 --- a/src/text_recognizer/networks/vision_transformer.py +++ /dev/null @@ -1,159 +0,0 @@ -"""VisionTransformer module. - -Splits each image into patches and feeds them to a transformer. - -""" - -from typing import Dict, Optional, Tuple, Type - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange -from loguru import logger -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import configure_backbone - - -class VisionTransformer(nn.Module): -    """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT.""" - -    def __init__( -        self, -        num_encoder_layers: int, -        num_decoder_layers: int, -        hidden_dim: int, -        vocab_size: int, -        num_heads: int, -        max_len: int, -        expansion_dim: int, -        dropout_rate: float, -        trg_pad_index: int, -        mlp_dim: Optional[int] = None, -        patch_size: Tuple[int, int] = (28, 28), -        stride: Tuple[int, int] = (1, 14), -        activation: str = "gelu", -        backbone: Optional[str] = None, -        backbone_args: Optional[Dict] = None, -    ) -> None: -        super().__init__() - -        self.patch_size = patch_size -        self.stride = stride -        self.trg_pad_index = trg_pad_index -        self.slidning_window = self._configure_sliding_window() -        self.character_embedding = nn.Embedding(vocab_size, hidden_dim) -        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) -        self.mlp_dim = mlp_dim - -        self.use_backbone = False -        if backbone is None: -            self.linear_projection = nn.Linear( -                self.patch_size[0] * self.patch_size[1], hidden_dim -            ) -        else: -            self.backbone = configure_backbone(backbone, backbone_args) -            if mlp_dim: -                self.mlp = nn.Linear(mlp_dim, hidden_dim) -            self.use_backbone = True - -        self.transformer = Transformer( -            num_encoder_layers, -            num_decoder_layers, -            hidden_dim, -            num_heads, -            expansion_dim, -            dropout_rate, -            activation, -        ) - -        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - -    def _configure_sliding_window(self) -> nn.Sequential: -        return nn.Sequential( -            nn.Unfold(kernel_size=self.patch_size, stride=self.stride), -            Rearrange( -                "b (c h w) t -> b t c h w", -                h=self.patch_size[0], -                w=self.patch_size[1], -                c=1, -            ), -        ) - -    def _create_trg_mask(self, trg: Tensor) -> Tensor: -        # Move this outside the transformer. -        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] -        trg_len = trg.shape[1] -        trg_sub_mask = torch.tril( -            torch.ones((trg_len, trg_len), device=trg.device) -        ).bool() -        trg_mask = trg_pad_mask & trg_sub_mask -        return trg_mask - -    def encoder(self, src: Tensor) -> Tensor: -        """Forward pass with the encoder of the transformer.""" -        return self.transformer.encoder(src) - -    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: -        """Forward pass with the decoder of the transformer + classification head.""" -        return self.head( -            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) -        ) - -    def _backbone(self, x: Tensor) -> Tensor: -        b, t = x.shape[:2] -        if self.use_backbone: -            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) -            x = self.backbone(x) -            if self.mlp_dim: -                x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t) -                x = self.mlp(x) -            else: -                x = rearrange(x, "(b t) h -> b t h", b=b, t=t) -        else: -            x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t) -            x = self.linear_projection(x) -        return x - -    def preprocess_input(self, src: Tensor) -> Tensor: -        """Encodes src with a backbone network and a positional encoding. - -        Args: -            src (Tensor): Input tensor. - -        Returns: -            Tensor: A input src to the transformer. - -        """ -        # If batch dimenstion is missing, it needs to be added. -        if len(src.shape) < 4: -            src = src[(None,) * (4 - len(src.shape))] -        src = self.slidning_window(src)  # .squeeze(-2) -        src = self._backbone(src) -        src = self.position_encoding(src) -        return src - -    def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: -        """Encodes target tensor with embedding and postion. - -        Args: -            trg (Tensor): Target tensor. - -        Returns: -            Tuple[Tensor, Tensor]: Encoded target tensor and target mask. - -        """ -        trg_mask = self._create_trg_mask(trg) -        trg = self.character_embedding(trg.long()) -        trg = self.position_encoding(trg) -        return trg, trg_mask - -    def forward(self, x: Tensor, trg: Tensor) -> Tensor: -        """Forward pass with vision transfomer.""" -        src = self.preprocess_input(x) -        trg, trg_mask = self.preprocess_target(trg) -        out = self.transformer(src, trg, trg_mask=trg_mask) -        logits = self.head(out) -        return logits diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py index aa79c12..28f3380 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/src/text_recognizer/networks/wide_resnet.py @@ -2,7 +2,7 @@  from functools import partial  from typing import Callable, Dict, List, Optional, Type, Union -from einops.layers.torch import Rearrange, Reduce +from einops.layers.torch import Reduce  import numpy as np  import torch  from torch import nn diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt Binary files differindex 726c723..344e0a3 100644 --- a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt +++ b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt Binary files differindex 6a9a915..f2dfd84 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt Binary files differindex 2d5a89b..e1add8d 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt Binary files differindex 59c06c2..04e1952 100644 --- a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt +++ b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt Binary files differindex 7fe1fa3..50a6a20 100644 --- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt +++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt  |