path: root/src/text_recognizer
author     aktersnurra <gustaf.rydholm@gmail.com>  2020-11-12 23:42:03 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-11-12 23:42:03 +0100
commit     8fdb6435e15703fa5b76df19728d905650ee1aef (patch)
tree       be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer
parent     dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff)
parent     6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Working CNN transformer.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/transforms.py               |  15
-rw-r--r--  src/text_recognizer/models/__init__.py                   |  14
-rw-r--r--  src/text_recognizer/models/base.py                       |   8
-rw-r--r--  src/text_recognizer/models/metrics.py                    |  21
-rw-r--r--  src/text_recognizer/models/transformer_encoder_model.py  | 111
-rw-r--r--  src/text_recognizer/models/transformer_model.py (renamed from src/text_recognizer/models/vision_transformer_model.py)  |  13
-rw-r--r--  src/text_recognizer/networks/__init__.py                 |   6
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py          |  46
-rw-r--r--  src/text_recognizer/networks/cnn_transformer_encoder.py  |  73
-rw-r--r--  src/text_recognizer/networks/loss/__init__.py            |   2
-rw-r--r--  src/text_recognizer/networks/loss/loss.py (renamed from src/text_recognizer/networks/loss.py)  |   0
-rw-r--r--  src/text_recognizer/networks/neural_machine_reader.py    | 201
-rw-r--r--  src/text_recognizer/networks/stn.py                      |   2
-rw-r--r--  src/text_recognizer/networks/util.py                     |   2
-rw-r--r--  src/text_recognizer/networks/vision_transformer.py       | 159
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py              |   2
-rw-r--r--  src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt  | bin 5628749 -> 132 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt  | bin 1273881 -> 132 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt  | bin 14953410 -> 133 bytes
-rw-r--r--  src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt  | bin 61946486 -> 133 bytes
-rw-r--r--  src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt  | bin 3457858 -> 132 bytes
21 files changed, 269 insertions(+), 406 deletions(-)
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 8deac7f..1105f23 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,7 +3,8 @@ import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
+import torch.nn.functional as F
+from torchvision.transforms import Compose, ToPILImage, ToTensor
from text_recognizer.datasets.util import EmnistMapper
@@ -16,6 +17,18 @@ class Transpose:
return np.array(image).swapaxes(0, 1)
+class Resize:
+ """Resizes a tensor to a specified width."""
+
+ def __init__(self, width: int = 952) -> None:
+ # The default is 952 because of the IAM dataset.
+ self.width = width
+
+ def __call__(self, image: Tensor) -> Tensor:
+ """Resize tensor in the last dimension."""
+ return F.interpolate(image, size=self.width, mode="nearest")
+
+
class AddTokens:
"""Adds start of sequence and end of sequence tokens to target tensor."""
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index 28aa52e..53340f1 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,18 +2,16 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
-from .metrics import accuracy, cer, wer
-from .transformer_encoder_model import TransformerEncoderModel
-from .vision_transformer_model import VisionTransformerModel
+from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .transformer_model import TransformerModel
__all__ = [
- "Model",
+ "accuracy",
+ "accuracy_ignore_pad",
"cer",
"CharacterModel",
"CRNNModel",
- "CNNTransfromerModel",
- "accuracy",
- "TransformerEncoderModel",
- "VisionTransformerModel",
+ "Model",
+ "TransformerModel",
"wer",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cc44c92..a945b41 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class Model(ABC):
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
- criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
@@ -221,7 +221,7 @@ class Model(ABC):
def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
- # If no network arguemnts are given, load pretrained weights if they exist.
+ # If no network arguments are given, load pretrained weights if they exist.
if self._network_args is None:
self.load_weights(network_fn)
else:
@@ -245,7 +245,7 @@ class Model(ABC):
self._optimizer = None
if self._optimizer and self._lr_scheduler is not None:
- if "OneCycleLR" in str(self._lr_scheduler):
+ if "steps_per_epoch" in self.lr_scheduler_args:
self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
# Assume lr scheduler should update at each epoch if not specified.
@@ -412,7 +412,7 @@ class Model(ABC):
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
if self._lr_scheduler is not None:
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
self._lr_scheduler["lr_scheduler"].load_state_dict(
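
The steps_per_epoch check above exists because OneCycleLR must know how many optimizer steps make up an epoch, which is only available once the train dataloader is built. A minimal sketch of that scheduler setup, with illustrative model and hyperparameters:

    import torch
    from torch import nn, optim

    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    steps_per_epoch = 100  # stands in for len(self.train_dataloader())
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.1, epochs=10, steps_per_epoch=steps_per_epoch
    )
    for _ in range(steps_per_epoch):
        optimizer.step()
        scheduler.step()  # OneCycleLR updates once per optimizer step
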
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 42c3c6e..af9adb5 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -6,7 +6,23 @@ from torch import Tensor
from text_recognizer.networks import greedy_decoder
-def accuracy(outputs: Tensor, labels: Tensor) -> float:
+def accuracy_ignore_pad(
+ output: Tensor,
+ target: Tensor,
+ pad_index: int = 79,
+ eos_index: int = 81,
+ seq_len: int = 97,
+) -> float:
+ """Sets all predictions after eos to pad."""
+ start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
+ end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
+ for start, stop in zip(start_indices, end_indices):
+ output[start + 1 : stop] = pad_index
+
+ return accuracy(output, target)
+
+
+def accuracy(outputs: Tensor, labels: Tensor) -> float:
"""Computes the accuracy.
Args:
@@ -17,10 +33,9 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- # eos_index = torch.nonzero(labels == eos, as_tuple=False)
- # eos_index = eos_index[0].item() if eos_index.nelement() else -1
_, predicted = torch.max(outputs, dim=-1)
+
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
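
A toy run of the masking step in accuracy_ignore_pad, using already-decoded index tensors for readability (the pad/eos indices and sequences here are made up): everything after each eos token is overwritten with pad before the accuracy comparison, so post-eos noise no longer hurts the score.

    import torch

    pad_index, eos_index, seq_len = 0, 9, 4
    target = torch.tensor([5, 9, 0, 0, 7, 3, 9, 0])  # two flattened sequences
    output = torch.tensor([5, 9, 2, 4, 7, 3, 9, 1])  # noise after each eos

    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
    for start, stop in zip(start_indices, end_indices):
        output[start + 1 : stop] = pad_index

    print(output)  # tensor([5, 9, 0, 0, 7, 3, 9, 0])
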
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
deleted file mode 100644
index e35e298..0000000
--- a/src/text_recognizer/models/transformer_encoder_model.py
+++ /dev/null
@@ -1,111 +0,0 @@
-"""Defines the CNN-Transformer class."""
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class TransformerEncoderModel(Model):
- """A class for only using the encoder part in the sequence modelling."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- # self.init_token = dataset_args["args"]["init_token"]
- self.pad_token = dataset_args["args"]["pad_token"]
- self.eos_token = dataset_args["args"]["eos_token"]
- if network_args is not None:
- self.max_len = network_args["max_len"]
- else:
- self.max_len = 128
-
- if self._mapper is None:
- self._mapper = EmnistMapper(
- # init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
- self.tensor_transform = ToTensor()
-
- self.softmax = nn.Softmax(dim=2)
-
- @torch.no_grad()
- def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
- logits = self.network(image)
- # Convert logits to probabilities.
- probs = self.softmax(logits).squeeze(0)
-
- confidence, pred_tokens = probs.max(1)
- pred_tokens = pred_tokens
-
- eos_index = torch.nonzero(
- pred_tokens == self._mapper(self.eos_token), as_tuple=False,
- )
-
- eos_index = eos_index[0].item() if eos_index.nelement() else -1
-
- predicted_characters = "".join(
- [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
- )
-
- confidence = np.min(confidence.tolist())
-
- return predicted_characters, confidence
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
-
- (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
- image
- )
-
- return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 3d36437..968a047 100644
--- a/src/text_recognizer/models/vision_transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -13,7 +13,7 @@ from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
-class VisionTransformerModel(Model):
+class TransformerModel(Model):
"""Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
def __init__(
@@ -50,10 +50,7 @@ class VisionTransformerModel(Model):
self.init_token = dataset_args["args"]["init_token"]
self.pad_token = dataset_args["args"]["pad_token"]
self.eos_token = dataset_args["args"]["eos_token"]
- if network_args is not None:
- self.max_len = network_args["max_len"]
- else:
- self.max_len = 120
+ self.max_len = 120
if self._mapper is None:
self._mapper = EmnistMapper(
@@ -67,7 +64,7 @@ class VisionTransformerModel(Model):
@torch.no_grad()
def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
- src = self.network.preprocess_input(image)
+ src = self.network.extract_image_features(image)
memory = self.network.encoder(src)
confidence_of_predictions = []
@@ -75,7 +72,7 @@ class VisionTransformerModel(Model):
for _ in range(self.max_len - 1):
trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg = self.network.preprocess_target(trg)
+ trg = self.network.target_embedding(trg)
logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
# Convert logits to probabilities.
@@ -101,7 +98,7 @@ class VisionTransformerModel(Model):
self.eval()
if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
image = self.tensor_transform(image)
# Rescale image between 0 and 1.
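
_generate_sentence above decodes greedily: the image is encoded once, then the growing target sequence is fed back through the decoder one token at a time until eos appears. A stand-alone sketch of that loop, where decode_step is a hypothetical stand-in for the target_embedding/decoder calls:

    import torch

    def greedy_decode(decode_step, init_index: int, eos_index: int, max_len: int = 120):
        trg_indices = [init_index]
        confidences = []
        for _ in range(max_len - 1):
            trg = torch.tensor(trg_indices)[None, :].long()
            logits = decode_step(trg)            # assumed shape (1, T, vocab)
            probs = torch.softmax(logits[0, -1], dim=-1)
            confidence, index = probs.max(0)
            trg_indices.append(index.item())
            confidences.append(confidence.item())
            if index.item() == eos_index:
                break
        # The lowest per-step probability is reported, as in the model above.
        return trg_indices, min(confidences)
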
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 6d88768..2cc1137 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,25 +1,20 @@
"""Network modules."""
from .cnn_transformer import CNNTransformer
-from .cnn_transformer_encoder import CNNTransformerEncoder
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
from .lenet import LeNet
-from .loss import EmbeddingLoss
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .sparse_mlp import SparseMLP
from .transformer import Transformer
from .util import sliding_window
-from .vision_transformer import VisionTransformer
from .wide_resnet import WideResidualNetwork
__all__ = [
"CNNTransformer",
- "CNNTransformerEncoder",
"ConvolutionalRecurrentNetwork",
"DenseNet",
- "EmbeddingLoss",
"greedy_decoder",
"MLP",
"LeNet",
@@ -28,6 +23,5 @@ __all__ = [
"sliding_window",
"Transformer",
"SparseMLP",
- "VisionTransformer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 3da2c9f..16c7a41 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,4 +1,4 @@
-"""A DETR style transfomers but for text recognition."""
+"""A CNN-Transformer for image to text recognition."""
from typing import Dict, Optional, Tuple
from einops import rearrange
@@ -11,7 +11,7 @@ from text_recognizer.networks.util import configure_backbone
class CNNTransformer(nn.Module):
- """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR."""
+ """CNN+Transfomer for image to sequence prediction."""
def __init__(
self,
@@ -25,22 +25,14 @@ class CNNTransformer(nn.Module):
dropout_rate: float,
trg_pad_index: int,
backbone: str,
- out_channels: int,
- max_len: int,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
-
self.backbone = configure_backbone(backbone, backbone_args)
self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-
- # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1)
-
self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
- self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
- self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
self.adaptive_pool = (
nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -78,8 +70,12 @@ class CNNTransformer(nn.Module):
self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
)
- def preprocess_input(self, src: Tensor) -> Tensor:
- """Encodes src with a backbone network and a positional encoding.
+ def extract_image_features(self, src: Tensor) -> Tensor:
+ """Extracts image features with a backbone neural network.
+
+ It seems the winning idea was to swap the channel and width dimensions and then
+ collapse the height dimension, so that width becomes the sequence axis fed to
+ the transformer. Training improves dramatically with this layout.
Args:
src (Tensor): Input tensor.
@@ -88,29 +84,19 @@ class CNNTransformer(nn.Module):
Tensor: An input src to the transformer.
"""
- # If batch dimenstion is missing, it needs to be added.
+ # If batch dimension is missing, it needs to be added.
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
src = self.backbone(src)
- # src = self.conv(src)
+ src = rearrange(src, "b c h w -> b w c h")
if self.adaptive_pool is not None:
src = self.adaptive_pool(src)
- H, W = src.shape[-2:]
- src = rearrange(src, "b t h w -> b t (h w)")
-
- # construct positional encodings
- pos = torch.cat(
- [
- self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
- self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
- ],
- dim=-1,
- ).unsqueeze(0)
- pos = rearrange(pos, "b h w l -> b l (h w)")
- src = pos + 0.1 * src
+ src = src.squeeze(3)
+ src = self.position_encoding(src)
+
return src
- def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes target tensor with embedding and postion.
Args:
@@ -126,9 +112,9 @@ class CNNTransformer(nn.Module):
def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
"""Forward pass with CNN transfomer."""
- h = self.preprocess_input(x)
+ h = self.extract_image_features(x)
trg_mask = self._create_trg_mask(trg)
- trg = self.preprocess_target(trg)
+ trg = self.target_embedding(trg)
out = self.transformer(h, trg, trg_mask=trg_mask)
logits = self.head(out)
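
A shape walk-through of extract_image_features, assuming a backbone that turns a (B, 1, 28, 952) line image into a (B, 256, 1, 119) feature map (the sizes are illustrative): after the rearrange and squeeze, width becomes the sequence axis and channels the embedding dimension.

    import torch
    from einops import rearrange

    features = torch.rand(4, 256, 1, 119)                 # (b, c, h, w)
    features = rearrange(features, "b c h w -> b w c h")  # (4, 119, 256, 1)
    features = features.squeeze(3)                        # (4, 119, 256)
    print(features.shape)  # ready for positional encoding + transformer
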
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
deleted file mode 100644
index 93626bf..0000000
--- a/src/text_recognizer/networks/cnn_transformer_encoder.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Network with a CNN backend and a transformer encoder head."""
-from typing import Dict
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding
-from text_recognizer.networks.util import configure_backbone
-
-
-class CNNTransformerEncoder(nn.Module):
- """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
-
- def __init__(
- self,
- backbone: str,
- backbone_args: Dict,
- mlp_dim: int,
- d_model: int,
- nhead: int = 8,
- dropout_rate: float = 0.1,
- activation: str = "relu",
- num_layers: int = 6,
- num_classes: int = 80,
- num_channels: int = 256,
- max_len: int = 97,
- ) -> None:
- super().__init__()
- self.d_model = d_model
- self.nhead = nhead
- self.dropout_rate = dropout_rate
- self.activation = activation
- self.num_layers = num_layers
-
- self.backbone = configure_backbone(backbone, backbone_args)
- self.position_encoding = PositionalEncoding(d_model, dropout_rate)
- self.encoder = self._configure_encoder()
-
- self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
-
- self.mlp = nn.Linear(mlp_dim, d_model)
-
- self.head = nn.Linear(d_model, num_classes)
-
- def _configure_encoder(self) -> nn.TransformerEncoder:
- encoder_layer = nn.TransformerEncoderLayer(
- d_model=self.d_model,
- nhead=self.nhead,
- dropout=self.dropout_rate,
- activation=self.activation,
- )
- norm = nn.LayerNorm(self.d_model)
- return nn.TransformerEncoder(
- encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
- )
-
- def forward(self, x: Tensor, targets: Tensor = None) -> Tensor:
- """Forward pass through the network."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
-
- x = self.conv(self.backbone(x))
- x = rearrange(x, "b c h w -> b c (h w)")
- x = self.mlp(x)
- x = self.position_encoding(x)
- x = rearrange(x, "b c h-> c b h")
- x = self.encoder(x)
- x = rearrange(x, "c b h-> b c h")
- logits = self.head(x)
-
- return logits
diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py
new file mode 100644
index 0000000..b489264
--- /dev/null
+++ b/src/text_recognizer/networks/loss/__init__.py
@@ -0,0 +1,2 @@
+"""Loss module."""
+from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss/loss.py
index cf9fa0d..cf9fa0d 100644
--- a/src/text_recognizer/networks/loss.py
+++ b/src/text_recognizer/networks/loss/loss.py
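
loss.py moves unchanged, but since LabelSmoothingCrossEntropy is now exported, here is a generic sketch of label-smoothing cross entropy for orientation (a common formulation, not necessarily this module's exact implementation):

    import torch
    import torch.nn.functional as F
    from torch import Tensor, nn

    class LabelSmoothingCrossEntropy(nn.Module):
        def __init__(self, smoothing: float = 0.1) -> None:
            super().__init__()
            self.smoothing = smoothing

        def forward(self, logits: Tensor, target: Tensor) -> Tensor:
            log_probs = F.log_softmax(logits, dim=-1)
            # Negative log likelihood of the true class.
            nll = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
            # Uniform smoothing term over all classes.
            smooth = -log_probs.mean(dim=-1)
            return ((1.0 - self.smoothing) * nll + self.smoothing * smooth).mean()

    loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1)
    loss = loss_fn(torch.randn(8, 80), torch.randint(0, 80, (8,)))
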
diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py
new file mode 100644
index 0000000..7f8c49b
--- /dev/null
+++ b/src/text_recognizer/networks/neural_machine_reader.py
@@ -0,0 +1,201 @@
+"""Sequence to sequence network with RNN cells."""
+# from typing import Dict, Optional, Tuple
+
+# from einops import rearrange
+# from einops.layers.torch import Rearrange
+# import torch
+# import torch.nn.functional as F
+# from torch import nn
+# from torch import Tensor
+
+# from text_recognizer.networks.util import configure_backbone
+
+
+# class Encoder(nn.Module):
+# def __init__(
+# self,
+# embedding_dim: int,
+# encoder_dim: int,
+# decoder_dim: int,
+# dropout_rate: float = 0.1,
+# ) -> None:
+# super().__init__()
+# self.rnn = nn.GRU(
+# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True
+# )
+# self.fc = nn.Sequential(
+# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh()
+# )
+# self.dropout = nn.Dropout(p=dropout_rate)
+
+# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+# """Encodes a sequence of tensors with a bidirectional GRU.
+
+# Args:
+# x (Tensor): An input sequence.
+
+# Shape:
+# - x: :math:`(T, N, E)`.
+# - output[0]: :math:`(T, N, 2 * E)`.
+# - output[1]: :math:`(T, N, D)`.
+
+# where T is the sequence length, N is the batch size, E is the
+# embedding/encoder dimension, and D is the decoder dimension.
+
+# Returns:
+# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the
+# encoder.
+
+# """
+
+# output, hidden = self.rnn(x)
+
+# # Get the hidden state from the forward and backward rnn.
+# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
+
+# # Apply fully connected layer and tanh activation.
+# hidden_state = self.fc(hidden_state)
+
+# return output, hidden_state
+
+
+# class Attention(nn.Module):
+# def __init__(self, encoder_dim: int, decoder_dim: int) -> None:
+# super().__init__()
+# self.attn = nn.Linear(
+# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim
+# )
+# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False)
+
+# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor:
+# """Short summary.
+
+# Args:
+# hidden_state (Tensor): Description of parameter `h`.
+# encoder_outputs (Tensor): Description of parameter `enc_out`.
+
+# Shape:
+# - x: :math:`(T, N, E)`.
+# - output[0]: :math:`(T, N, 2 * E)`.
+# - output[1]: :math:`(T, N, D)`.
+
+# where T is the sequence length, N is the batch size, E is the
+# embedding/encoder dimension, and D is the decoder dimension.
+
+# Returns:
+# Tensor: Description of returned object.
+
+# """
+# t = encoder_outputs.shape[0]
+# # repeat decoder hidden state src_len times
+# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1)
+
+# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
+
+# # Calculate the energy between the decoders previous hidden state and the
+# # encoders hidden states.
+# energy = torch.tanh(
+# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2))
+# )
+
+# attention = self.value(energy).squeeze(2)
+
+# # Apply softmax on the attention to squeeze it between 0 and 1.
+# attention = F.softmax(attention, dim=1)
+
+# return attention
+
+
+# class Decoder(nn.Module):
+# def __init__(
+# self,
+# embedding_dim: int,
+# encoder_dim: int,
+# decoder_dim: int,
+# output_dim: int,
+# dropout_rate: float = 0.1,
+# ) -> None:
+# super().__init__()
+# self.output_dim = output_dim
+# self.embedding = nn.Embedding(output_dim, embedding_dim)
+# self.attention = Attention(encoder_dim, decoder_dim)
+# self.rnn = nn.GRU(
+# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim
+# )
+
+# self.head = nn.Linear(
+# in_features=2 * encoder_dim + embedding_dim + decoder_dim,
+# out_features=output_dim,
+# )
+# self.dropout = nn.Dropout(p=dropout_rate)
+
+# def forward(
+# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor
+# ) -> Tensor:
+# # input = [batch size]
+# # hidden = [batch size, dec hid dim]
+# # encoder_outputs = [src len, batch size, enc hid dim * 2]
+# trg = trg.unsqueeze(0)
+# trg_embedded = self.dropout(self.embedding(trg))
+
+# a = self.attention(hidden_state, encoder_outputs)
+#
+# # bmm needs a as (b, 1, t) and the encoder outputs as (b, t, 2 * e).
+# a = a.unsqueeze(1)
+# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
+# weighted = torch.bmm(a, encoder_outputs)
+
+# # Permute the tensor so the weighted context is (1, b, 2 * e).
+# weighted = rearrange(weighted, "b a e2 -> a b e2")
+
+# rnn_input = torch.cat((trg_embedded, weighted), dim=2)
+
+# output, hidden = self.rnn(rnn_input, hidden_state.unsqueeze(0))
+
+# # seq len, n layers and n directions will always be 1 in this decoder, therefore:
+# # output = [1, batch size, dec hid dim]
+# # hidden = [1, batch size, dec hid dim]
+# # this also means that output == hidden
+# assert (output == hidden).all()
+
+# trg_embedded = trg_embedded.squeeze(0)
+# output = output.squeeze(0)
+# weighted = weighted.squeeze(0)
+
+# logits = self.head(torch.cat((output, weighted, trg_embedded), dim=1))
+
+# # logits = [batch size, output dim]
+
+# return logits, hidden.squeeze(0)
+
+
+# class NeuralMachineReader(nn.Module):
+# def __init__(
+# self,
+# embedding_dim: int,
+# encoder_dim: int,
+# decoder_dim: int,
+# output_dim: int,
+# backbone: Optional[str] = None,
+# backbone_args: Optional[Dict] = None,
+# adaptive_pool_dim: Tuple = (None, 1),
+# dropout_rate: float = 0.1,
+# teacher_forcing_ratio: float = 0.5,
+# ) -> None:
+# super().__init__()
+
+# self.backbone = configure_backbone(backbone, backbone_args)
+# self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim))
+
+# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate)
+# self.decoder = Decoder(
+# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate
+# )
+# self.teacher_forcing_ratio = teacher_forcing_ratio
+
+# def extract_image_features(self, x: Tensor) -> Tensor:
+# x = self.backbone(x)
+# x = rearrange(x, "b c h w -> b w c h")
+# x = self.adaptive_pool(x)
+# x = x.squeeze(3)
+# return x
+
+# def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+# # x = [batch size, height, width]
+# # trg = [trg len, batch size]
+# z = self.extract_image_features(x)
diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py
index b031128..e9d216f 100644
--- a/src/text_recognizer/networks/stn.py
+++ b/src/text_recognizer/networks/stn.py
@@ -13,7 +13,7 @@ class SpatialTransformerNetwork(nn.Module):
Network that learns how to perform spatial transformations on the input image in order to enhance the
geometric invariance of the model.
- # TODO: add arguements to make it more general.
+ # TODO: add arguments to make it more general.
"""
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index b31e640..e2d7955 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -24,7 +24,7 @@ def sliding_window(
"""
unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
- # Preform the slidning window, unsqueeze as the channel dimesion is lost.
+ # Perform the sliding window; unsqueeze as the channel dimension is lost.
c = images.shape[1]
patches = unfold(images)
patches = rearrange(
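
For context on the sliding window above: nn.Unfold flattens each patch into a column, and the rearrange restores the channel/patch structure. A toy run with assumed sizes:

    import torch
    from torch import nn
    from einops import rearrange

    images = torch.rand(2, 1, 28, 56)                # (b, c, h, w)
    unfold = nn.Unfold(kernel_size=(28, 28), stride=(1, 14))
    patches = unfold(images)                         # (2, 1 * 28 * 28, 3)
    patches = rearrange(patches, "b (c h w) t -> b t c h w", c=1, h=28, w=28)
    print(patches.shape)  # torch.Size([2, 3, 1, 28, 28])
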
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
deleted file mode 100644
index f227954..0000000
--- a/src/text_recognizer/networks/vision_transformer.py
+++ /dev/null
@@ -1,159 +0,0 @@
-"""VisionTransformer module.
-
-Splits each image into patches and feeds them to a transformer.
-
-"""
-
-from typing import Dict, Optional, Tuple, Type
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import configure_backbone
-
-
-class VisionTransformer(nn.Module):
- """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- max_len: int,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- mlp_dim: Optional[int] = None,
- patch_size: Tuple[int, int] = (28, 28),
- stride: Tuple[int, int] = (1, 14),
- activation: str = "gelu",
- backbone: Optional[str] = None,
- backbone_args: Optional[Dict] = None,
- ) -> None:
- super().__init__()
-
- self.patch_size = patch_size
- self.stride = stride
- self.trg_pad_index = trg_pad_index
- self.slidning_window = self._configure_sliding_window()
- self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
- self.mlp_dim = mlp_dim
-
- self.use_backbone = False
- if backbone is None:
- self.linear_projection = nn.Linear(
- self.patch_size[0] * self.patch_size[1], hidden_dim
- )
- else:
- self.backbone = configure_backbone(backbone, backbone_args)
- if mlp_dim:
- self.mlp = nn.Linear(mlp_dim, hidden_dim)
- self.use_backbone = True
-
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
- )
-
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
- def _configure_sliding_window(self) -> nn.Sequential:
- return nn.Sequential(
- nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
- Rearrange(
- "b (c h w) t -> b t c h w",
- h=self.patch_size[0],
- w=self.patch_size[1],
- c=1,
- ),
- )
-
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
- )
-
- def _backbone(self, x: Tensor) -> Tensor:
- b, t = x.shape[:2]
- if self.use_backbone:
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
- x = self.backbone(x)
- if self.mlp_dim:
- x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
- x = self.mlp(x)
- else:
- x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
- else:
- x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
- x = self.linear_projection(x)
- return x
-
- def preprocess_input(self, src: Tensor) -> Tensor:
- """Encodes src with a backbone network and a positional encoding.
-
- Args:
- src (Tensor): Input tensor.
-
- Returns:
- Tensor: A input src to the transformer.
-
- """
- # If batch dimenstion is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
- src = self.slidning_window(src) # .squeeze(-2)
- src = self._backbone(src)
- src = self.position_encoding(src)
- return src
-
- def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes target tensor with embedding and postion.
-
- Args:
- trg (Tensor): Target tensor.
-
- Returns:
- Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
-
- """
- trg_mask = self._create_trg_mask(trg)
- trg = self.character_embedding(trg.long())
- trg = self.position_encoding(trg)
- return trg, trg_mask
-
- def forward(self, x: Tensor, trg: Tensor) -> Tensor:
- """Forward pass with vision transfomer."""
- src = self.preprocess_input(x)
- trg, trg_mask = self.preprocess_target(trg)
- out = self.transformer(src, trg, trg_mask=trg_mask)
- logits = self.head(out)
- return logits
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index aa79c12..28f3380 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -2,7 +2,7 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union
-from einops.layers.torch import Rearrange, Reduce
+from einops.layers.torch import Reduce
import numpy as np
import torch
from torch import nn
diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
index 726c723..344e0a3 100644
--- a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
index 6a9a915..f2dfd84 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
index 2d5a89b..e1add8d 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
+++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
index 59c06c2..04e1952 100644
--- a/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/LineCTCModel_EmnistLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
index 7fe1fa3..50a6a20 100644
--- a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
+++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt
Binary files differ