summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transformer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/networks/transformer
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/networks/transformer')
-rw-r--r--src/text_recognizer/networks/transformer/__init__.py3
-rw-r--r--src/text_recognizer/networks/transformer/attention.py93
-rw-r--r--src/text_recognizer/networks/transformer/positional_encoding.py32
-rw-r--r--src/text_recognizer/networks/transformer/transformer.py264
4 files changed, 0 insertions, 392 deletions
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
deleted file mode 100644
index 9febc88..0000000
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Transformer modules."""
-from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
deleted file mode 100644
index cce1ecc..0000000
--- a/src/text_recognizer/networks/transformer/attention.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""Implementes the attention module for the transformer."""
-from typing import Optional, Tuple
-
-from einops import rearrange
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class MultiHeadAttention(nn.Module):
- """Implementation of multihead attention."""
-
- def __init__(
- self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.num_heads = num_heads
- self.fc_q = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_k = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_v = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
-
- self._init_weights()
-
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def _init_weights(self) -> None:
- nn.init.normal_(
- self.fc_q.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.normal_(
- self.fc_k.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.normal_(
- self.fc_v.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.xavier_normal_(self.fc_out.weight)
-
- def scaled_dot_product_attention(
- self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
- ) -> Tensor:
- """Calculates the scaled dot product attention."""
-
- # Compute the energy.
- energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
- query.shape[-1]
- )
-
- # If we have a mask for padding some inputs.
- if mask is not None:
- energy = energy.masked_fill(mask == 0, -np.inf)
-
- # Compute the attention from the energy.
- attention = torch.softmax(energy, dim=3)
-
- out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
- out = rearrange(out, "b head l v -> b l (head v)")
- return out, attention
-
- def forward(
- self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
- ) -> Tuple[Tensor, Tensor]:
- """Forward pass for computing the multihead attention."""
- # Get the query, key, and value tensor.
- query = rearrange(
- self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
- )
- key = rearrange(
- self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
- )
- value = rearrange(
- self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
- )
-
- out, attention = self.scaled_dot_product_attention(query, key, value, mask)
-
- out = self.fc_out(out)
- out = self.dropout(out)
- return out, attention
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
deleted file mode 100644
index 1ba5537..0000000
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class PositionalEncoding(nn.Module):
- """Encodes a sense of distance or time for transformer networks."""
-
- def __init__(
- self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
- ) -> None:
- super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
- self.max_len = max_len
-
- pe = torch.zeros(max_len, hidden_dim)
- position = torch.arange(0, max_len).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
- )
-
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(0)
- self.register_buffer("pe", pe)
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
- return self.dropout(x)
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
deleted file mode 100644
index dd180c4..0000000
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ /dev/null
@@ -1,264 +0,0 @@
-"""Transfomer module."""
-import copy
-from typing import Dict, Optional, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.transformer.attention import MultiHeadAttention
-from text_recognizer.networks.util import activation_function
-
-
-class GEGLU(nn.Module):
- """GLU activation for improving feedforward activations."""
-
- def __init__(self, dim_in: int, dim_out: int) -> None:
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward propagation."""
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
- return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
-
-
-class _IntraLayerConnection(nn.Module):
- """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
-
- def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
- super().__init__()
- self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward(self, src: Tensor, residual: Tensor) -> Tensor:
- return self.norm(self.dropout(src) + residual)
-
-
-class _ConvolutionalLayer(nn.Module):
- def __init__(
- self,
- hidden_dim: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
-
- in_projection = (
- nn.Sequential(
- nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
- )
- if activation != "glu"
- else GEGLU(hidden_dim, expansion_dim)
- )
-
- self.layer = nn.Sequential(
- in_projection,
- nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layer(x)
-
-
-class EncoderLayer(nn.Module):
- """Transfomer encoding layer."""
-
- def __init__(
- self,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
- self.cnn = _ConvolutionalLayer(
- hidden_dim, expansion_dim, dropout_rate, activation
- )
- self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-
- def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
- """Forward pass through the encoder."""
- # First block.
- # Multi head attention.
- out, _ = self.self_attention(src, src, src, mask)
-
- # Add & norm.
- out = self.block1(out, src)
-
- # Second block.
- # Apply 1D-convolution.
- cnn_out = self.cnn(out)
-
- # Add & norm.
- out = self.block2(cnn_out, out)
-
- return out
-
-
-class Encoder(nn.Module):
- """Transfomer encoder module."""
-
- def __init__(
- self,
- num_layers: int,
- encoder_layer: Type[nn.Module],
- norm: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
- self.norm = norm
-
- def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
- """Forward pass through all encoder layers."""
- for layer in self.layers:
- src = layer(src, src_mask)
-
- if self.norm is not None:
- src = self.norm(src)
-
- return src
-
-
-class DecoderLayer(nn.Module):
- """Transfomer decoder layer."""
-
- def __init__(
- self,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float = 0.0,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
- self.multihead_attention = MultiHeadAttention(
- hidden_dim, num_heads, dropout_rate
- )
- self.cnn = _ConvolutionalLayer(
- hidden_dim, expansion_dim, dropout_rate, activation
- )
- self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
-
- def forward(
- self,
- trg: Tensor,
- memory: Tensor,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass of the layer."""
- out, _ = self.self_attention(trg, trg, trg, trg_mask)
- trg = self.block1(out, trg)
-
- out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
- trg = self.block2(out, trg)
-
- out = self.cnn(trg)
- out = self.block3(out, trg)
-
- return out
-
-
-class Decoder(nn.Module):
- """Transfomer decoder module."""
-
- def __init__(
- self,
- decoder_layer: Type[nn.Module],
- num_layers: int,
- norm: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
-
- def forward(
- self,
- trg: Tensor,
- memory: Tensor,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass through the decoder."""
- for layer in self.layers:
- trg = layer(trg, memory, trg_mask, memory_mask)
-
- if self.norm is not None:
- trg = self.norm(trg)
-
- return trg
-
-
-class Transformer(nn.Module):
- """Transformer network."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
-
- # Configure encoder.
- encoder_norm = nn.LayerNorm(hidden_dim)
- encoder_layer = EncoderLayer(
- hidden_dim, num_heads, expansion_dim, dropout_rate, activation
- )
- self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
-
- # Configure decoder.
- decoder_norm = nn.LayerNorm(hidden_dim)
- decoder_layer = DecoderLayer(
- hidden_dim, num_heads, expansion_dim, dropout_rate, activation
- )
- self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
-
- self._reset_parameters()
-
- def _reset_parameters(self) -> None:
- for p in self.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
-
- def forward(
- self,
- src: Tensor,
- trg: Tensor,
- src_mask: Optional[Tensor] = None,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass through the transformer."""
- if src.shape[0] != trg.shape[0]:
- print(trg.shape)
- raise RuntimeError("The batch size of the src and trg must be the same.")
- if src.shape[2] != trg.shape[2]:
- raise RuntimeError(
- "The number of features for the src and trg must be the same."
- )
-
- memory = self.encoder(src, src_mask)
- output = self.decoder(trg, memory, trg_mask, memory_mask)
- return output