summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 01:45:34 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 01:45:34 +0200
commitffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d (patch)
treedb8c78232e588b12d7a8b408682783e0b5858615 /text_recognizer/networks
parentcf2a827db5798a245dd5207685251675d311dbec (diff)
Add comments
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/conv_transformer.py5
-rw-r--r--text_recognizer/networks/convnext/convnext.py7
-rw-r--r--text_recognizer/networks/convnext/downsample.py4
-rw-r--r--text_recognizer/networks/convnext/norm.py3
-rw-r--r--text_recognizer/networks/convnext/residual.py3
-rw-r--r--text_recognizer/networks/image_encoder.py2
-rw-r--r--text_recognizer/networks/text_decoder.py14
-rw-r--r--text_recognizer/networks/transformer/embeddings/axial.py24
8 files changed, 42 insertions, 20 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index e36a786..d36162a 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,7 +1,6 @@
"""Base network module."""
from typing import Type
-import torch
from torch import Tensor, nn
from text_recognizer.networks.transformer.decoder import Decoder
@@ -28,11 +27,11 @@ class ConvTransformer(nn.Module):
return self.decoder(tokens, img_features)
def forward(self, img: Tensor, tokens: Tensor) -> Tensor:
- """Encodes images into word piece logtis.
+ """Encodes images into token logtis.
Args:
img (Tensor): Input image(s).
- tokens (Tensor): Target word embeddings.
+ tokens (Tensor): token embeddings.
Shapes:
- img: :math: `(B, 1, H, W)`
diff --git a/text_recognizer/networks/convnext/convnext.py b/text_recognizer/networks/convnext/convnext.py
index b4dfad7..9419a15 100644
--- a/text_recognizer/networks/convnext/convnext.py
+++ b/text_recognizer/networks/convnext/convnext.py
@@ -1,3 +1,4 @@
+"""ConvNext module."""
from typing import Optional, Sequence
from torch import Tensor, nn
@@ -8,7 +9,9 @@ from text_recognizer.networks.convnext.norm import LayerNorm
class ConvNextBlock(nn.Module):
- def __init__(self, dim, dim_out, mult):
+ """ConvNext block."""
+
+ def __init__(self, dim: int, dim_out: int, mult: int) -> None:
super().__init__()
self.ds_conv = nn.Conv2d(
dim, dim, kernel_size=(7, 7), padding="same", groups=dim
@@ -21,7 +24,7 @@ class ConvNextBlock(nn.Module):
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
h = self.ds_conv(x)
h = self.net(h)
return h + self.res_conv(x)
diff --git a/text_recognizer/networks/convnext/downsample.py b/text_recognizer/networks/convnext/downsample.py
index c28ecca..a8a0466 100644
--- a/text_recognizer/networks/convnext/downsample.py
+++ b/text_recognizer/networks/convnext/downsample.py
@@ -1,3 +1,4 @@
+"""Convnext downsample module."""
from typing import Tuple
from einops.layers.torch import Rearrange
@@ -5,6 +6,8 @@ from torch import Tensor, nn
class Downsample(nn.Module):
+ """Downsamples feature maps by patches."""
+
def __init__(self, dim: int, dim_out: int, factors: Tuple[int, int]) -> None:
super().__init__()
s1, s2 = factors
@@ -14,4 +17,5 @@ class Downsample(nn.Module):
)
def forward(self, x: Tensor) -> Tensor:
+ """Applies patch function."""
return self.fn(x)
diff --git a/text_recognizer/networks/convnext/norm.py b/text_recognizer/networks/convnext/norm.py
index 23cf07a..3355de9 100644
--- a/text_recognizer/networks/convnext/norm.py
+++ b/text_recognizer/networks/convnext/norm.py
@@ -4,11 +4,14 @@ from torch import Tensor, nn
class LayerNorm(nn.Module):
+ """Layer norm for convolutions."""
+
def __init__(self, dim: int) -> None:
super().__init__()
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x: Tensor) -> Tensor:
+ """Applies layer norm."""
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=1, keepdim=True)
diff --git a/text_recognizer/networks/convnext/residual.py b/text_recognizer/networks/convnext/residual.py
index 8e76ae9..dfc2847 100644
--- a/text_recognizer/networks/convnext/residual.py
+++ b/text_recognizer/networks/convnext/residual.py
@@ -5,9 +5,12 @@ from torch import Tensor, nn
class Residual(nn.Module):
+ """Residual layer."""
+
def __init__(self, fn: Callable) -> None:
super().__init__()
self.fn = fn
def forward(self, x: Tensor) -> Tensor:
+ """Applies residual fn."""
return self.fn(x) + x
diff --git a/text_recognizer/networks/image_encoder.py b/text_recognizer/networks/image_encoder.py
index b5fd0c5..ab60560 100644
--- a/text_recognizer/networks/image_encoder.py
+++ b/text_recognizer/networks/image_encoder.py
@@ -9,7 +9,7 @@ from text_recognizer.networks.transformer.embeddings.axial import (
class ImageEncoder(nn.Module):
- """Base transformer network."""
+ """Encodes images to latent embeddings."""
def __init__(
self,
diff --git a/text_recognizer/networks/text_decoder.py b/text_recognizer/networks/text_decoder.py
index 7ee6720..7498663 100644
--- a/text_recognizer/networks/text_decoder.py
+++ b/text_recognizer/networks/text_decoder.py
@@ -1,6 +1,4 @@
"""Text decoder."""
-from typing import Optional, Type
-
import torch
from torch import Tensor, nn
@@ -8,26 +6,24 @@ from text_recognizer.networks.transformer.decoder import Decoder
class TextDecoder(nn.Module):
- """Decoder transformer network."""
+ """Decodes images to token logits."""
def __init__(
self,
- hidden_dim: int,
+ dim: int,
num_classes: int,
pad_index: Tensor,
decoder: Decoder,
) -> None:
super().__init__()
- self.hidden_dim = hidden_dim
+ self.dim = dim
self.num_classes = num_classes
self.pad_index = pad_index
self.decoder = decoder
self.token_embedding = nn.Embedding(
- num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
- )
- self.to_logits = nn.Linear(
- in_features=self.hidden_dim, out_features=self.num_classes
+ num_embeddings=self.num_classes, embedding_dim=self.dim
)
+ self.to_logits = nn.Linear(in_features=self.dim, out_features=self.num_classes)
def forward(self, tokens: Tensor, img_features: Tensor) -> Tensor:
"""Decodes latent images embedding into word pieces.
diff --git a/text_recognizer/networks/transformer/embeddings/axial.py b/text_recognizer/networks/transformer/embeddings/axial.py
index 25d8f60..9b872a9 100644
--- a/text_recognizer/networks/transformer/embeddings/axial.py
+++ b/text_recognizer/networks/transformer/embeddings/axial.py
@@ -1,17 +1,24 @@
"""Axial attention for multi-dimensional data.
Stolen from:
- https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
+ https://github.com/lucidrains/axial-attention/blob/
+ eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L100
"""
from functools import reduce
from operator import mul
+from typing import Optional, Sequence
import torch
-from torch import nn
+from torch import Tensor, nn
class AxialPositionalEmbedding(nn.Module):
- def __init__(self, dim, axial_shape, axial_dims=None):
+ def __init__(
+ self,
+ dim: int,
+ axial_shape: Sequence[int],
+ axial_dims: Optional[Sequence[int]] = None,
+ ) -> None:
super().__init__()
self.dim = dim
@@ -37,7 +44,8 @@ class AxialPositionalEmbedding(nn.Module):
ax_emb = nn.Parameter(torch.zeros(ax_shape).normal_(0, 1))
self.weights.append(ax_emb)
- def forward(self, x):
+ def forward(self, x: Tensor) -> Tensor:
+ """Returns axial positional embedding."""
b, t, _ = x.shape
assert (
t <= self.max_seq_len
@@ -77,8 +85,14 @@ class ParameterList(object):
class AxialPositionalEmbeddingImage(nn.Module):
- def __init__(self, dim, axial_shape, axial_dims=None):
+ def __init__(
+ self,
+ dim: int,
+ axial_shape: Sequence[int],
+ axial_dims: Optional[Sequence[int]] = None,
+ ) -> None:
super().__init__()
+ axial_dims = (dim // 2, dim // 2) if axial_dims is None else axial_dims
assert len(axial_shape) == 2, "Axial shape must have 2 dimensions for images"
self.dim = dim
self.pos_emb = AxialPositionalEmbedding(dim, axial_shape, axial_dims)