diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/residual_network.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 12 |
6 files changed, 10 insertions, 40 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 4dcaf2e..979149f 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,3 +1,2 @@ """Network modules""" from .image_transformer import ImageTransformer - diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index edebca9..9ed67a4 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -13,7 +13,7 @@ import math from typing import Dict, List, Union, Sequence, Tuple, Type from einops import rearrange -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor @@ -34,7 +34,7 @@ class ImageTransformer(nn.Module): self, input_shape: Sequence[int], output_shape: Sequence[int], - encoder: Union[OmegaConf, Dict], + encoder: Union[DictConfig, Dict], mapping: str, num_decoder_layers: int = 4, hidden_dim: int = 256, @@ -101,7 +101,7 @@ class ImageTransformer(nn.Module): nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) @staticmethod - def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index da7553d..c33f419 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,11 +20,7 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial( - Conv2dAuto, - kernel_size=3, - bias=False, - ) + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index b10f93a..d7e3d08 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,12 +392,7 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, + data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 67ed0d9..8847aba 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,12 +44,7 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, ) def _build_decompression_block( @@ -78,9 +73,7 @@ class Decoder(nn.Module): ) if i < len(self.upsampling): - modules.append( - nn.Upsample(size=self.upsampling[i]), - ) + modules.append(nn.Upsample(size=self.upsampling[i]),) if dropout is not None: modules.append(dropout) @@ -109,12 +102,7 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d( - self.embedding_dim, - channels[0], - kernel_size=1, - stride=1, - ) + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index ede5c31..d3adac5 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - dropout: Optional[Type[nn.Module]], + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -138,12 +135,7 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d( - channels[-1], - self.embedding_dim, - kernel_size=1, - stride=1, - ) + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) ) return nn.Sequential(*encoder) |