From ffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 Oct 2022 01:45:34 +0200 Subject: Add comments --- text_recognizer/data/tokenizer.py | 1 + text_recognizer/networks/conv_transformer.py | 5 ++--- text_recognizer/networks/convnext/convnext.py | 7 +++++-- text_recognizer/networks/convnext/downsample.py | 4 ++++ text_recognizer/networks/convnext/norm.py | 3 +++ text_recognizer/networks/convnext/residual.py | 3 +++ text_recognizer/networks/image_encoder.py | 2 +- text_recognizer/networks/text_decoder.py | 14 +++++-------- .../networks/transformer/embeddings/axial.py | 24 +++++++++++++++++----- 9 files changed, 43 insertions(+), 20 deletions(-) diff --git a/text_recognizer/data/tokenizer.py b/text_recognizer/data/tokenizer.py index a5f44e6..12617a1 100644 --- a/text_recognizer/data/tokenizer.py +++ b/text_recognizer/data/tokenizer.py @@ -37,6 +37,7 @@ class Tokenizer: @property def num_classes(self) -> int: + """Return number of classes in the dataset.""" return self.__len__() def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]: diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index e36a786..d36162a 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,7 +1,6 @@ """Base network module.""" from typing import Type -import torch from torch import Tensor, nn from text_recognizer.networks.transformer.decoder import Decoder @@ -28,11 +27,11 @@ class ConvTransformer(nn.Module): return self.decoder(tokens, img_features) def forward(self, img: Tensor, tokens: Tensor) -> Tensor: - """Encodes images into word piece logtis. + """Encodes images into token logtis. Args: img (Tensor): Input image(s). - tokens (Tensor): Target word embeddings. + tokens (Tensor): token embeddings. Shapes: - img: :math: `(B, 1, H, W)` diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py index b4dfad7..9419a15 100644 --- a/text_recognizer/networks/convnext/convnext.py +++ b/text_recognizer/networks/convnext/convnext.py @@ -1,3 +1,4 @@ +"""ConvNext module.""" from typing import Optional, Sequence from torch import Tensor, nn @@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm class ConvNextBlock(nn.Module): - def __init__(self, dim, dim_out, mult): + """ConvNext block.""" + + def __init__(self, dim: int, dim_out: int, mult: int) -> None: super().__init__() self.ds_conv = nn.Conv2d( dim, dim, kernel_size=(7, 7), padding="same", groups=dim @@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module): ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: h = self.ds_conv(x) h = self.net(h) return h + self.res_conv(x) diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py index c28ecca..a8a0466 100644 --- a/text_recognizer/networks/convnext/downsample.py +++ b/text_recognizer/networks/convnext/downsample.py @@ -1,3 +1,4 @@ +"""Convnext downsample module.""" from typing import Tuple from einops.layers.torch import Rearrange @@ -5,6 +6,8 @@ from torch import Tensor, nn class Downsample(nn.Module): + """Downsamples feature maps by patches.""" + def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None: super().__init__() s1, s2 = factors @@ -14,4 +17,5 @@ class Downsample(nn.Module): ) def forward(self, x: Tensor) -> Tensor: + """Applies patch function.""" return self.fn(x) diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py index 23cf07a..3355de9 100644 --- a/text_recognizer/networks/convnext/norm.py +++ b/text_recognizer/networks/convnext/norm.py @@ -4,11 +4,14 @@ from torch import Tensor, nn class LayerNorm(nn.Module): + """Layer norm for convolutions.""" + def __init__(self, dim: int) -> None: super().__init__() self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) def forward(self, x: Tensor) -> Tensor: + """Applies layer norm.""" eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py index 8e76ae9..dfc2847 100644 --- a/text_recognizer/networks/convnext/residual.py +++ b/text_recognizer/networks/convnext/residual.py @@ -5,9 +5,12 @@ from torch import Tensor, nn class Residual(nn.Module): + """Residual layer.""" + def __init__(self, fn: Callable) -> None: super().__init__() self.fn = fn def forward(self, x: Tensor) -> Tensor: + """Applies residual fn.""" return self.fn(x) + x diff --git a/text_recognizer/networks/image_encoder.py b/text_recognizer/networks/image_encoder.py index b5fd0c5..ab60560 100644 --- a/text_recognizer/networks/image_encoder.py +++ b/text_recognizer/networks/image_encoder.py @@ -9,7 +9,7 @@ from text_recognizer.networks.transformer.embeddings.axial import ( class ImageEncoder(nn.Module): - """Base transformer network.""" + """Encodes images to latent embeddings.""" def __init__( self, diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py index 7ee6720..7498663 100644 --- a/text_recognizer/networks/text_decoder.py +++ b/text_recognizer/networks/text_decoder.py @@ -1,6 +1,4 @@ """Text decoder.""" -from typing import Optional, Type - import torch from torch import Tensor, nn @@ -8,26 +6,24 @@ from text_recognizer.networks.transformer.decoder import Decoder class TextDecoder(nn.Module): - """Decoder transformer network.""" + """Decodes images to token logits.""" def __init__( self, - hidden_dim: int, + dim: int, num_classes: int, pad_index: Tensor, decoder: Decoder, ) -> None: super().__init__() - self.hidden_dim = hidden_dim + self.dim = dim self.num_classes = num_classes self.pad_index = pad_index self.decoder = decoder self.token_embedding = nn.Embedding( - num_embeddings=self.num_classes, embedding_dim=self.hidden_dim - ) - self.to_logits = nn.Linear( - in_features=self.hidden_dim, out_features=self.num_classes + num_embeddings=self.num_classes, embedding_dim=self.dim ) + self.to_logits = nn.Linear(in_features=self.dim, out_features=self.num_classes) def forward(self, tokens: Tensor, img_features: Tensor) -> Tensor: """Decodes latent images embedding into word pieces. diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py index 25d8f60..9b872a9 100644 --- a/text_recognizer/networks/transformer/embeddings/axial.py +++ b/text_recognizer/networks/transformer/embeddings/axial.py @@ -1,17 +1,24 @@ """Axial attention for multi-dimensional data. Stolen from: - https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 + https://github.com/lucidrains/axial-attention/blob/ + eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100 """ from functools import reduce from operator import mul +from typing import Optional, Sequence import torch -from torch import nn +from torch import Tensor, nn class AxialPositionalEmbedding(nn.Module): - def __init__(self, dim, axial_shape, axial_dims=None): + def __init__( + self, + dim: int, + axial_shape: Sequence[int], + axial_dims: Optional[Sequence[int]] = None, + ) -> None: super().__init__() self.dim = dim @@ -37,7 +44,8 @@ class AxialPositionalEmbedding(nn.Module): ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1)) self.weights.append(ax_emb) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """Returns axial positional embedding.""" b, t, _ = x.shape assert ( t <= self.max_seq_len @@ -77,8 +85,14 @@ class ParameterList(object): class AxialPositionalEmbeddingImage(nn.Module): - def __init__(self, dim, axial_shape, axial_dims=None): + def __init__( + self, + dim: int, + axial_shape: Sequence[int], + axial_dims: Optional[Sequence[int]] = None, + ) -> None: super().__init__() + axial_dims = (dim // 2, dim // 2) if axial_dims is None else axial_dims assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images" self.dim = dim self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims) -- cgit v1.2.3-70-g09d2