summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- text_recognizer/networks/transformer/embeddings/axial.py | 24
1 file changed, 19 insertions, 5 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
index 25d8f60..9b872a9 100644
--- a/text_recognizer/networks/transformer/embeddings/axial.py
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -1,17 +1,24 @@
"""Axial attention for multi-dimensional data.
Stolen from:
- https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
+ https://github.com/lucidrains/axial-attention/blob/
+ eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
"""
from functools import reduce
from operator import mul
+from typing import Optional, Sequence
import torch
-from torch import nn
+from torch import Tensor, nn
class AxialPositionalEmbedding(nn.Module):
- def __init__(self, dim, axial_shape, axial_dims=None):
+ def __init__(
+ self,
+ dim: int,
+ axial_shape: Sequence[int],
+ axial_dims: Optional[Sequence[int]] = None,
+ ) -> None:
super().__init__()
self.dim = dim
@@ -37,7 +44,8 @@ class AxialPositionalEmbedding(nn.Module):
ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1))
self.weights.append(ax_emb)
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
+ """Returns axial positional embedding."""
b, t, _ = x.shape
assert (
t <= self.max_seq_len
@@ -77,8 +85,14 @@ class ParameterList(object):
class AxialPositionalEmbeddingImage(nn.Module):
- def __init__(self, dim, axial_shape, axial_dims=None):
+ def __init__(
+ self,
+ dim: int,
+ axial_shape: Sequence[int],
+ axial_dims: Optional[Sequence[int]] = None,
+ ) -> None:
super().__init__()
+ axial_dims = (dim // 2, dim // 2) if axial_dims is None else axial_dims
assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images"
self.dim = dim
self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims)