-rw-r--r--  text_recognizer/networks/transformer/mlp.py          46
-rw-r--r--  text_recognizer/networks/transformer/norm.py         22
-rw-r--r--  text_recognizer/networks/transformer/transformer.py  16
3 files changed, 74 insertions, 10 deletions
diff --git a/text_recognizer/networks/transformer/mlp.py b/text_recognizer/networks/transformer/mlp.py
new file mode 100644
index 0000000..4028ab3
--- /dev/null
+++ b/text_recognizer/networks/transformer/mlp.py
@@ -0,0 +1,46 @@
+"""Feedforward layer in transformer.
+
+Stolen from lucidrains:
+ https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+"""
+from typing import Optional
+
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x, gate = self.fc(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        expansion_factor: int = 4,
+        glu: bool = True,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+        inner_dim = dim * expansion_factor
+        dim_out = dim_out if dim_out is not None else dim
+        in_projection = (
+            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+            if not glu
+            else GEGLU(dim, inner_dim)
+        )
+
+        self.mlp = nn.Sequential(
+            in_projection, nn.Dropout(dropout_rate), nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.mlp(x)
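As a quick illustration (not part of this commit), the new FeedForward defaults to a GEGLU input projection and keeps the model width, expanding only internally by expansion_factor; the dimensions below are assumed for the example:

import torch
from text_recognizer.networks.transformer.mlp import FeedForward

# Hypothetical sizes: batch 16, sequence length 100, model width 256.
ff = FeedForward(dim=256, dropout_rate=0.1)  # glu=True by default -> GEGLU projection to 256 * 4 = 1024
x = torch.randn(16, 100, 256)
out = ff(x)  # projected back to dim_out = dim = 256
assert out.shape == x.shape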
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
new file mode 100644
index 0000000..99a5291
--- /dev/null
+++ b/text_recognizer/networks/transformer/norm.py
@@ -0,0 +1,22 @@
+"""Normalization layers for transfromers.
+
+Copied from lucidrains:
+ https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+
+"""
+from typing import Callable, Dict, Tuple
+
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class Rezero(nn.Module):
+    def __init__(self, fn: Callable) -> None:
+        super().__init__()
+        self.fn = fn
+        self.g = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x: Tensor, **kwargs: Dict) -> Tuple[Tensor, ...]:
+        x, *rest = self.fn(x, **kwargs)
+        return (x * self.g, *rest)
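For context (again not part of the commit), Rezero applies the ReZero-style residual gate: it expects the wrapped callable to return a tuple whose first element is the tensor to scale, and the learned gate g starts at zero so the branch is initially a no-op. A minimal sketch with a hypothetical wrapped module:

import torch
from torch import nn
from text_recognizer.networks.transformer.norm import Rezero

class TupleIdentity(nn.Module):
    # Hypothetical stand-in for e.g. an attention block returning (output, weights).
    def forward(self, x):
        return x, None

layer = Rezero(TupleIdentity())
x = torch.randn(2, 10, 64)
out, weights = layer(x)
assert torch.all(out == 0)  # g is initialised to zero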
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index dd180c4..5ac2787 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -41,7 +41,7 @@ class _IntraLayerConnection(nn.Module):
        return self.norm(self.dropout(src) + residual)
-class _ConvolutionalLayer(nn.Module):
+class FeedForward(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
@@ -82,9 +82,7 @@ class EncoderLayer(nn.Module):
    ) -> None:
        super().__init__()
        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-        self.cnn = _ConvolutionalLayer(
-            hidden_dim, expansion_dim, dropout_rate, activation
-        )
+        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
@@ -99,10 +97,10 @@ class EncoderLayer(nn.Module):
        # Second block.
        # Apply 1D-convolution.
-        cnn_out = self.cnn(out)
+        mlp_out = self.mlp(out)
        # Add & norm.
-        out = self.block2(cnn_out, out)
+        out = self.block2(mlp_out, out)
        return out
@@ -148,9 +146,7 @@ class DecoderLayer(nn.Module):
        self.multihead_attention = MultiHeadAttention(
            hidden_dim, num_heads, dropout_rate
        )
-        self.cnn = _ConvolutionalLayer(
-            hidden_dim, expansion_dim, dropout_rate, activation
-        )
+        self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
        self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
@@ -169,7 +165,7 @@ class DecoderLayer(nn.Module):
        out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
        trg = self.block2(out, trg)
-        out = self.cnn(trg)
+        out = self.mlp(trg)
        out = self.block3(out, trg)
        return out