From 6cb08a110620ee09fe9d8a5d008197a801d025df Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 23:06:12 +0100 Subject: working on seq2seq network --- .../networks/neural_machine_reader.py | 180 +++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 src/text_recognizer/networks/neural_machine_reader.py diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py new file mode 100644 index 0000000..540a7d2 --- /dev/null +++ b/src/text_recognizer/networks/neural_machine_reader.py @@ -0,0 +1,180 @@ +from typing import Dict, Optional, Tuple + +from einops import rearrange +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import configure_backbone + + +class Encoder(nn.Module): + + def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, dropout_rate: float = 0.1) -> None: + super().__init__() + self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True) + self.fc = nn.Sequential(nn.Linear(in_features=2*encoder_dim, out_features=decoder_dim), nn.Tanh()) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes a sequence of tensors with a bidirectional GRU. + + Args: + x (Tensor): A input sequence. + + Shape: + - x: :math:`(T, N, E)`. + - output[0]: :math:`(T, N, 2 * E)`. + - output[1]: :math:`(T, N, D)`. + + where T is the sequence length, N is the batch size, E is the + embedding/encoder dimension, and D is the decoder dimension. + + Returns: + Tuple[Tensor, Tensor]: The encoder output and the hidden state of the + encoder. + + """ + + output, hidden = self.rnn(x) + + # Get the hidden state from the forward and backward rnn. + hidden_state = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)) + + # Apply fully connected layer and tanh activation. + hidden_state = self.fc(hidden_state) + + return output, hidden_state + + +class Attention(nn.Module): + + def __init__(self, encoder_dim: int, decoder_dim: int) -> None: + super().__init__() + self.atten = nn.Linear(in_features=2*encoder_dim + decoder_dim, out_features=decoder_dim) + self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False) + + def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: + """Short summary. + + Args: + hidden_state (Tensor): Description of parameter `h`. + encoder_outputs (Tensor): Description of parameter `enc_out`. + + Shape: + - x: :math:`(T, N, E)`. + - output[0]: :math:`(T, N, 2 * E)`. + - output[1]: :math:`(T, N, D)`. + + where T is the sequence length, N is the batch size, E is the + embedding/encoder dimension, and D is the decoder dimension. + + Returns: + Tensor: Description of returned object. + + """ + t, b = enc_out.shape[:2] + #repeat decoder hidden state src_len times + hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1) + + encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2") + + # Calculate the energy between the decoders previous hidden state and the + # encoders hidden states. + energy = torch.tanh(self.attn(torch.cat((hidden_state, encoder_outputs), dim = 2))) + + attention = self.value(energy).squeeze(2) + + # Apply softmax on the attention to squeeze it between 0 and 1. + attention = F.softmax(attention, dim=1) + + return attention + + +class Decoder(nn.Module): + + def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, output_dim: int, dropout_rate: float = 0.1) -> None: + super().__init__() + self.output_dim = output_dim + self.embedding = nn.Embedding(output_dim, embedding_dim) + self.attention = Attention(encoder_dim, decoder_dim) + self.rnn = nn.GRU(input_size=2*encoder_dim + embedding_dim, hidden_size=decoder_dim) + + self.head = nn.Linear(in_features=2*encoder_dim+embedding_dim+decoder_dim, out_features=output_dim) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: + #input = [batch size] + #hidden = [batch size, dec hid dim] + #encoder_outputs = [src len, batch size, enc hid dim * 2] + trg = trg.unsqueeze(0) + trg_embedded = self.dropout(self.embedding(trg)) + + a = self.attention(hidden_state, encoder_outputs) + + weighted = torch.bmm(a, encoder_outputs) + + # Permutate the tensor. + weighted = rearrange(weighted, "b a e2 -> a b e2") + + rnn_input = torch.cat((trg_embedded, weighted), dim = 2) + + output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) + + #seq len, n layers and n directions will always be 1 in this decoder, therefore: + #output = [1, batch size, dec hid dim] + #hidden = [1, batch size, dec hid dim] + #this also means that output == hidden + assert (output == hidden).all() + + trg_embedded = trg_embedded.squeeze(0) + output = output.squeeze(0) + weighted = weighted.squeeze(0) + + logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim = 1)) + + #prediction = [batch size, output dim] + + return logits, hidden.squeeze(0) + + +class NeuralMachineReader(nn.Module): + + def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, output_dim: int, backbone: Optional[str] = None, + backbone_args: Optional[Dict] = None, patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), dropout_rate: float = 0.1, teacher_forcing_ratio: float = 0.5) -> None: + super().__init__() + self.patch_size = patch_size + self.stride = stride + self.sliding_window = self._configure_sliding_window() + + self.backbone = + self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate) + self.decoder = Decoder(embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate) + self.teacher_forcing_ratio = teacher_forcing_ratio + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def forward(self, x: Tensor, trg: Tensor) -> Tensor: + #x = [batch size, height, width] + #trg = [trg len, batch size] + + # Create image patches with a sliding window kernel. + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + + x = self.backbone(x) + x = rearrange(x, "(b t) h -> t b h", b=b, t=t) -- cgit v1.2.3-70-g09d2