summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/image_transformer.py10
-rw-r--r--text_recognizer/networks/residual_network.py6
-rw-r--r--text_recognizer/networks/transducer/transducer.py7
-rw-r--r--text_recognizer/networks/vqvae/decoder.py18
-rw-r--r--text_recognizer/networks/vqvae/encoder.py12
5 files changed, 40 insertions, 13 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index aa024e0..85a84d2 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -1,9 +1,9 @@
"""A Transformer with a cnn backbone.
The network encodes a image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
spatial information. The resulting feature are then set to a transformer decoder
-together with the target tokens.
+together with the target tokens.
TODO: Local attention for transformer.j
@@ -107,9 +107,7 @@ class ImageTransformer(nn.Module):
encoder_class = getattr(network_module, encoder.type)
return encoder_class(**encoder.args)
- def _configure_mapping(
- self, mapping: str
- ) -> Tuple[List[str], Dict[str, int]]:
+ def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
if mapping == "emnist":
mapping, inverse_mapping, _ = emnist_mapping()
@@ -125,7 +123,7 @@ class ImageTransformer(nn.Module):
Tensor: Image features.
Shapes:
- - image: :math: `(B, C, H, W)`
+ - image: :math: `(B, C, H, W)`
- latent: :math: `(B, T, C)`
"""
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index c33f419..da7553d 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d):
def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
"""3x3 convolution with batch norm."""
- conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ conv3x3 = partial(
+ Conv2dAuto,
+ kernel_size=3,
+ bias=False,
+ )
return nn.Sequential(
conv3x3(in_channels, out_channels, *args, **kwargs),
nn.BatchNorm2d(out_channels),
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index d7e3d08..b10f93a 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,7 +392,12 @@ def load_transducer_loss(
transitions = gtn.load(str(processed_path / transitions))
preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
)
num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..67ed0d9 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
)
def _build_decompression_block(
@@ -73,7 +78,9 @@ class Decoder(nn.Module):
)
if i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ modules.append(
+ nn.Upsample(size=self.upsampling[i]),
+ )
if dropout is not None:
modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ nn.Conv2d(
+ self.embedding_dim,
+ channels[0],
+ kernel_size=1,
+ stride=1,
+ )
)
# Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..ede5c31 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ nn.Conv2d(
+ channels[-1],
+ self.embedding_dim,
+ kernel_size=1,
+ stride=1,
+ )
)
return nn.Sequential(*encoder)