diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-10-02 01:45:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-10-02 01:45:34 +0200 |
commit | ffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d (patch) | |
tree | db8c78232e588b12d7a8b408682783e0b5858615 /text_recognizer/networks/transformer/embeddings | |
parent | cf2a827db5798a245dd5207685251675d311dbec (diff) |
Add comments
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings')
-rw-r--r-- | text_recognizer/networks/transformer/embeddings/axial.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py index 25d8f60..9b872a9 100644 --- a/text_recognizer/networks/transformer/embeddings/axial.py +++ b/text_recognizer/networks/transformer/embeddings/axial.py @@ -1,17 +1,24 @@ """Axial attention for multi-dimensional data. Stolen from: - https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 + https://github.com/lucidrains/axial-attention/blob/ + eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 """ from functools import reduce from operator import mul +from typing import Optional, Sequence import torch -from torch import nn +from torch import Tensor, nn class AxialPositionalEmbedding(nn.Module): - def __init__(self, dim, axial_shape, axial_dims=None): + def __init__( + self, + dim: int, + axial_shape: Sequence[int], + axial_dims: Optional[Sequence[int]] = None, + ) -> None: super().__init__() self.dim = dim @@ -37,7 +44,8 @@ class AxialPositionalEmbedding(nn.Module): ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1)) self.weights.append(ax_emb) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """Returns axial positional embedding.""" b, t, _ = x.shape assert ( t <= self.max_seq_len @@ -77,8 +85,14 @@ class ParameterList(object): class AxialPositionalEmbeddingImage(nn.Module): - def __init__(self, dim, axial_shape, axial_dims=None): + def __init__( + self, + dim: int, + axial_shape: Sequence[int], + axial_dims: Optional[Sequence[int]] = None, + ) -> None: super().__init__() + axial_dims = (dim // 2, dim // 2) if axial_dims is None else axial_dims assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images" self.dim = dim self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims) |