path: root/src/text_recognizer/networks/neural_machine_reader.py
Diffstat (limited to 'src/text_recognizer/networks/neural_machine_reader.py')
-rw-r--r--  src/text_recognizer/networks/neural_machine_reader.py  361
1 file changed, 191 insertions, 170 deletions
diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py
index 540a7d2..7f8c49b 100644
--- a/src/text_recognizer/networks/neural_machine_reader.py
+++ b/src/text_recognizer/networks/neural_machine_reader.py
@@ -1,180 +1,201 @@
-from typing import Dict, Optional, Tuple
+"""Sequence to sequence network with RNN cells."""
+# from typing import Dict, Optional, Tuple
-from einops import rearrange
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
+# from einops import rearrange
+# from einops.layers.torch import Rearrange
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
-from text_recognizer.networks.util import configure_backbone
+# from text_recognizer.networks.util import configure_backbone
-class Encoder(nn.Module):
+# class Encoder(nn.Module):
+# def __init__(
+# self,
+# embedding_dim: int,
+# encoder_dim: int,
+# decoder_dim: int,
+# dropout_rate: float = 0.1,
+# ) -> None:
+# super().__init__()
+# self.rnn = nn.GRU(
+# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True
+# )
+# self.fc = nn.Sequential(
+# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh()
+# )
+# self.dropout = nn.Dropout(p=dropout_rate)
+
+# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+# """Encodes a sequence of tensors with a bidirectional GRU.
+
+# Args:
+# x (Tensor): An input sequence.
+
+# Shape:
+# - x: :math:`(T, N, E)`.
+# - output[0]: :math:`(T, N, 2 * E)`.
+# - output[1]: :math:`(N, D)`.
+
+# where T is the sequence length, N is the batch size, E is the
+# embedding/encoder dimension, and D is the decoder dimension.
+
+# Returns:
+# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the
+# encoder.
+
+# """
+
+# output, hidden = self.rnn(x)
- def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, dropout_rate: float = 0.1) -> None:
- super().__init__()
- self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True)
- self.fc = nn.Sequential(nn.Linear(in_features=2*encoder_dim, out_features=decoder_dim), nn.Tanh())
- self.dropout = nn.Dropout(p=dropout_rate)
+# # Get the hidden state from the forward and backward rnn.
+# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
+
+# # Apply fully connected layer and tanh activation.
+# hidden_state = self.fc(hidden_state)
+
+# return output, hidden_state
+
+
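A minimal shape sketch of the encoder above, using illustrative dimensions (the values below are assumptions, not taken from the repository):

import torch
from torch import nn

T, N, E, H, D = 10, 4, 32, 48, 64  # seq len, batch, embedding, encoder, decoder dims
rnn = nn.GRU(input_size=E, hidden_size=H, bidirectional=True)
fc = nn.Sequential(nn.Linear(in_features=2 * H, out_features=D), nn.Tanh())

x = torch.randn(T, N, E)
output, hidden = rnn(x)  # output: (T, N, 2 * H), hidden: (2, N, H)
# Concatenate the last forward and backward hidden states, then project.
hidden_state = fc(torch.cat((hidden[-2], hidden[-1]), dim=1))  # (N, D)
assert output.shape == (T, N, 2 * H) and hidden_state.shape == (N, D)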
+# class Attention(nn.Module):
+# def __init__(self, encoder_dim: int, decoder_dim: int) -> None:
+# super().__init__()
+# self.attn = nn.Linear(
+# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim
+# )
+# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False)
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes a sequence of tensors with a bidirectional GRU.
+# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor:
+# """Short summary.
- Args:
- x (Tensor): An input sequence.
+# Args:
+# hidden_state (Tensor): The previous hidden state of the decoder.
+# encoder_outputs (Tensor): The output sequence of the encoder.
- Shape:
- - x: :math:`(T, N, E)`.
- - output[0]: :math:`(T, N, 2 * E)`.
- - output[1]: :math:`(N, D)`.
+# Shape:
+# - hidden_state: :math:`(N, D)`.
+# - encoder_outputs: :math:`(T, N, 2 * E)`.
+# - output: :math:`(N, T)`.
+
+# where T is the sequence length, N is the batch size, E is the
+# embedding/encoder dimension, and D is the decoder dimension.
+
+# Returns:
+# Tensor: The attention weights, normalized over the sequence dimension.
+
+# """
+# t, b = encoder_outputs.shape[:2]
+# # Repeat the decoder hidden state src_len times.
+# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1)
+
+# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
+
+# # Calculate the energy between the decoder's previous hidden state and the
+# # encoder's hidden states.
+# energy = torch.tanh(
+# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2))
+# )
+
+# attention = self.value(energy).squeeze(2)
+
+# # Softmax normalizes the scores into attention weights that sum to one.
+# attention = F.softmax(attention, dim=1)
+
+# return attention
+
+
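The additive attention above can be checked shape by shape; the dimensions are the same illustrative assumptions:

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

T, N, H, D = 10, 4, 48, 64
attn = nn.Linear(in_features=2 * H + D, out_features=D)
value = nn.Linear(in_features=D, out_features=1, bias=False)

hidden_state = torch.randn(N, D)            # previous decoder hidden state
encoder_outputs = torch.randn(T, N, 2 * H)  # bidirectional encoder outputs

h = hidden_state.unsqueeze(1).repeat(1, T, 1)          # (N, T, D)
enc = rearrange(encoder_outputs, "t n h2 -> n t h2")   # (N, T, 2 * H)
energy = torch.tanh(attn(torch.cat((h, enc), dim=2)))  # (N, T, D)
weights = F.softmax(value(energy).squeeze(2), dim=1)   # (N, T)
assert torch.allclose(weights.sum(dim=1), torch.ones(N))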
+# class Decoder(nn.Module):
+# def __init__(
+# self,
+# embedding_dim: int,
+# encoder_dim: int,
+# decoder_dim: int,
+# output_dim: int,
+# dropout_rate: float = 0.1,
+# ) -> None:
+# super().__init__()
+# self.output_dim = output_dim
+# self.embedding = nn.Embedding(output_dim, embedding_dim)
+# self.attention = Attention(encoder_dim, decoder_dim)
+# self.rnn = nn.GRU(
+# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim
+# )
+
+# self.head = nn.Linear(
+# in_features=2 * encoder_dim + embedding_dim + decoder_dim,
+# out_features=output_dim,
+# )
+# self.dropout = nn.Dropout(p=dropout_rate)
+
+# def forward(
+# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor
+# ) -> Tuple[Tensor, Tensor]:
+# # trg = [batch size]
+# # hidden_state = [batch size, dec hid dim]
+# # encoder_outputs = [src len, batch size, enc hid dim * 2]
+# trg = trg.unsqueeze(0)
+# trg_embedded = self.dropout(self.embedding(trg))
- where T is the sequence length, N is the batch size, E is the
- embedding/encoder dimension, and D is the decoder dimension.
+# a = self.attention(hidden_state, encoder_outputs).unsqueeze(1)
+
+# # Batch first for the batched matrix multiplication.
+# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
+# weighted = torch.bmm(a, encoder_outputs)
- Returns:
- Tuple[Tensor, Tensor]: The encoder output and the hidden state of the
- encoder.
-
- """
-
- output, hidden = self.rnn(x)
-
- # Get the hidden state from the forward and backward rnn.
- hidden_state = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
-
- # Apply fully connected layer and tanh activation.
- hidden_state = self.fc(hidden_state)
-
- return output, hidden_state
-
-
-class Attention(nn.Module):
-
- def __init__(self, encoder_dim: int, decoder_dim: int) -> None:
- super().__init__()
- self.attn = nn.Linear(in_features=2*encoder_dim + decoder_dim, out_features=decoder_dim)
- self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False)
-
- def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor:
- """Short summary.
-
- Args:
- hidden_state (Tensor): The previous hidden state of the decoder.
- encoder_outputs (Tensor): The output sequence of the encoder.
-
- Shape:
- - hidden_state: :math:`(N, D)`.
- - encoder_outputs: :math:`(T, N, 2 * E)`.
- - output: :math:`(N, T)`.
-
- where T is the sequence length, N is the batch size, E is the
- embedding/encoder dimension, and D is the decoder dimension.
-
- Returns:
- Tensor: The attention weights, normalized over the sequence dimension.
-
- """
- t, b = encoder_outputs.shape[:2]
- # Repeat the decoder hidden state src_len times.
- hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1)
-
- encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
-
- # Calculate the energy between the decoder's previous hidden state and the
- # encoder's hidden states.
- energy = torch.tanh(self.attn(torch.cat((hidden_state, encoder_outputs), dim = 2)))
-
- attention = self.value(energy).squeeze(2)
-
- # Apply softmax on the attention to squeeze it between 0 and 1.
- attention = F.softmax(attention, dim=1)
-
- return attention
-
-
-class Decoder(nn.Module):
-
- def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, output_dim: int, dropout_rate: float = 0.1) -> None:
- super().__init__()
- self.output_dim = output_dim
- self.embedding = nn.Embedding(output_dim, embedding_dim)
- self.attention = Attention(encoder_dim, decoder_dim)
- self.rnn = nn.GRU(input_size=2*encoder_dim + embedding_dim, hidden_size=decoder_dim)
-
- self.head = nn.Linear(in_features=2*encoder_dim+embedding_dim+decoder_dim, out_features=output_dim)
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward(self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor) -> Tuple[Tensor, Tensor]:
- #input = [batch size]
- #hidden = [batch size, dec hid dim]
- #encoder_outputs = [src len, batch size, enc hid dim * 2]
- trg = trg.unsqueeze(0)
- trg_embedded = self.dropout(self.embedding(trg))
-
- a = self.attention(hidden_state, encoder_outputs).unsqueeze(1)
-
- encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
- weighted = torch.bmm(a, encoder_outputs)
-
- # Permutate the tensor.
- weighted = rearrange(weighted, "b a e2 -> a b e2")
-
- rnn_input = torch.cat((trg_embedded, weighted), dim = 2)
-
- output, hidden = self.rnn(rnn_input, hidden_state.unsqueeze(0))
-
- #seq len, n layers and n directions will always be 1 in this decoder, therefore:
- #output = [1, batch size, dec hid dim]
- #hidden = [1, batch size, dec hid dim]
- #this also means that output == hidden
- assert (output == hidden).all()
-
- trg_embedded = trg_embedded.squeeze(0)
- output = output.squeeze(0)
- weighted = weighted.squeeze(0)
-
- logits = self.head(torch.cat((output, weighted, trg_embedded), dim = 1))
-
- #prediction = [batch size, output dim]
-
- return logits, hidden.squeeze(0)
-
-
-class NeuralMachineReader(nn.Module):
-
- def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, output_dim: int, backbone: Optional[str] = None,
- backbone_args: Optional[Dict] = None, patch_size: Tuple[int, int] = (28, 28),
- stride: Tuple[int, int] = (1, 14), dropout_rate: float = 0.1, teacher_forcing_ratio: float = 0.5) -> None:
- super().__init__()
- self.patch_size = patch_size
- self.stride = stride
- self.sliding_window = self._configure_sliding_window()
-
- self.backbone = configure_backbone(backbone, backbone_args)
- self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate)
- self.decoder = Decoder(embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate)
- self.teacher_forcing_ratio = teacher_forcing_ratio
-
- def _configure_sliding_window(self) -> nn.Sequential:
- return nn.Sequential(
- nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
- Rearrange(
- "b (c h w) t -> b t c h w",
- h=self.patch_size[0],
- w=self.patch_size[1],
- c=1,
- ),
- )
-
- def forward(self, x: Tensor, trg: Tensor) -> Tensor:
- #x = [batch size, height, width]
- #trg = [trg len, batch size]
-
- # Create image patches with a sliding window kernel.
- x = self.sliding_window(x)
-
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
-
- x = self.backbone(x)
- x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+# # Permutate the tensor.
+# weighted = rearrange(weighted, "b a e2 -> a b e2")
+
+# rnn_input = torch.cat((trg_embedded, weighted), dim=2)
+
+# output, hidden = self.rnn(rnn_input, hidden_state.unsqueeze(0))
+
+# # seq len, n layers and n directions will always be 1 in this decoder, therefore:
+# # output = [1, batch size, dec hid dim]
+# # hidden = [1, batch size, dec hid dim]
+# # this also means that output == hidden
+# assert (output == hidden).all()
+
+# trg_embedded = trg_embedded.squeeze(0)
+# output = output.squeeze(0)
+# weighted = weighted.squeeze(0)
+
+# logits = self.head(torch.cat((output, weighted, trg_embedded), dim=1))
+
+# # prediction = [batch size, output dim]
+
+# return logits, hidden.squeeze(0)
+
+
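One decoder step with stand-in tensors, showing how the attention weights pool the encoder outputs into the context that feeds the GRU and the classification head (V is an assumed vocabulary size; all values are illustrative):

import torch
from torch import nn
from einops import rearrange

T, N, E, H, D, V = 10, 4, 32, 48, 64, 80
embedding = nn.Embedding(V, E)
rnn = nn.GRU(input_size=2 * H + E, hidden_size=D)
head = nn.Linear(in_features=2 * H + E + D, out_features=V)

trg = torch.randint(0, V, (N,))                    # current target tokens
weights = torch.softmax(torch.randn(N, T), dim=1)  # stand-in attention weights
enc = torch.randn(N, T, 2 * H)                     # batch-first encoder outputs
hidden_state = torch.randn(N, D)

emb = embedding(trg).unsqueeze(0)                   # (1, N, E)
weighted = torch.bmm(weights.unsqueeze(1), enc)     # (N, 1, 2 * H)
weighted = rearrange(weighted, "n s h2 -> s n h2")  # (1, N, 2 * H)
output, hidden = rnn(torch.cat((emb, weighted), dim=2), hidden_state.unsqueeze(0))
logits = head(torch.cat((output, weighted, emb), dim=2)).squeeze(0)  # (N, V)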
+# class NeuralMachineReader(nn.Module):
+# def __init__(
+# self,
+# embedding_dim: int,
+# encoder_dim: int,
+# decoder_dim: int,
+# output_dim: int,
+# backbone: Optional[str] = None,
+# backbone_args: Optional[Dict] = None,
+# adaptive_pool_dim: Tuple = (None, 1),
+# dropout_rate: float = 0.1,
+# teacher_forcing_ratio: float = 0.5,
+# ) -> None:
+# super().__init__()
+
+# self.backbone = configure_backbone(backbone, backbone_args)
+# self.adaptive_pool = nn.AdaptiveAvgPool2d(adaptive_pool_dim)
+
+# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate)
+# self.decoder = Decoder(
+# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate
+# )
+# self.teacher_forcing_ratio = teacher_forcing_ratio
+
+# def extract_image_features(self, x: Tensor) -> Tensor:
+# x = self.backbone(x)
+# x = rearrange(x, "b c h w -> b w c h")
+# x = self.adaptive_pool(x)
+# x = x.squeeze(3)
+# return x
+
+# def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+# # x = [batch size, height, width]
+# # trg = [trg len, batch size]
+# z = self.extract_image_features(x)
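For the commented-out extract_image_features above, a sketch with a stand-in convolution (configure_backbone is repository-specific, so a single Conv2d is assumed here): width becomes the sequence axis and the height is pooled away.

import torch
from torch import nn
from einops import rearrange

backbone = nn.Conv2d(1, 16, kernel_size=3, padding=1)  # assumed stand-in backbone
adaptive_pool = nn.AdaptiveAvgPool2d((None, 1))        # keep dim 2, pool dim 3 to size 1

x = torch.randn(2, 1, 28, 952)          # (batch, channels, height, width)
z = backbone(x)                         # (2, 16, 28, 952)
z = rearrange(z, "b c h w -> b w c h")  # (2, 952, 16, 28)
z = adaptive_pool(z).squeeze(3)         # (2, 952, 16): a width-indexed feature sequence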
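The removed _configure_sliding_window can be exercised the same way: nn.Unfold slices the image into overlapping patches and Rearrange turns them into a sequence of single-channel images (the input size is an assumption).

import torch
from torch import nn
from einops.layers.torch import Rearrange

patch_size, stride = (28, 28), (1, 14)
sliding_window = nn.Sequential(
    nn.Unfold(kernel_size=patch_size, stride=stride),
    Rearrange("b (c h w) t -> b t c h w", c=1, h=patch_size[0], w=patch_size[1]),
)

x = torch.randn(2, 1, 28, 952)  # (batch, channels, height, width)
patches = sliding_window(x)     # (2, 67, 1, 28, 28)
# t = (952 - 28) // 14 + 1 = 67 window positions across the width.
assert patches.shape == (2, 67, 1, 28, 28)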