From ffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 Oct 2022 01:45:34 +0200 Subject: Add comments --- .../networks/transformer/embeddings/axial.py | 24 +++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) (limited to 'text_recognizer/networks/transformer') diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py index 25d8f60..9b872a9 100644 --- a/text_recognizer/networks/transformer/embeddings/axial.py +++ b/text_recognizer/networks/transformer/embeddings/axial.py @@ -1,17 +1,24 @@ """Axial attention for multi-dimensional data. Stolen from: - https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 + https://github.com/lucidrains/axial-attention/blob/ + eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 """ from functools import reduce from operator import mul +from typing import Optional, Sequence import torch -from torch import nn +from torch import Tensor, nn class AxialPositionalEmbedding(nn.Module): - def __init__(self, dim, axial_shape, axial_dims=None): + def __init__( + self, + dim: int, + axial_shape: Sequence[int], + axial_dims: Optional[Sequence[int]] = None, + ) -> None: super().__init__() self.dim = dim @@ -37,7 +44,8 @@ class AxialPositionalEmbedding(nn.Module): ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1)) self.weights.append(ax_emb) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """Returns axial positional embedding.""" b, t, _ = x.shape assert ( t <= self.max_seq_len @@ -77,8 +85,14 @@ class ParameterList(object): class AxialPositionalEmbeddingImage(nn.Module): - def __init__(self, dim, axial_shape, axial_dims=None): + def __init__( + self, + dim: int, + axial_shape: Sequence[int], + axial_dims: Optional[Sequence[int]] = None, + ) -> None: super().__init__() + axial_dims = (dim // 2, dim // 2) if axial_dims is None else axial_dims assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images" self.dim = dim self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims) -- cgit v1.2.3-70-g09d2