diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 10 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 9 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 6 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 10 | ||||
-rw-r--r-- | text_recognizer/networks/residual_network.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 7 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 12 |
9 files changed, 60 insertions, 24 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index d2529b4..c144341 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -18,10 +18,16 @@ class IAMExtendedParagraphs(BaseDataModule): super().__init__(batch_size, num_workers) self.iam_paragraphs = IAMParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index f588587..314d458 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -161,7 +161,10 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "crop_shape": { + "min": crop_shapes.min(axis=0), + "max": crop_shapes.max(axis=0), + }, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -282,7 +285,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + degrees=1, + shear=(-10, 10), + interpolation=InterpolationMode.BILINEAR, ), ] else: diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 1004f48..11d1eb1 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -15,7 +15,7 @@ class LitBaseModel(pl.LightningModule): def __init__( self, - network: Type[nn,Module], + network: Type[nn.Module], optimizer: Union[OmegaConf, Dict], lr_scheduler: Union[OmegaConf, Dict], criterion: Union[OmegaConf, Dict], @@ -40,14 +40,14 @@ class LitBaseModel(pl.LightningModule): args = {} or criterion.args return getattr(nn, criterion.type)(**args) - def _configure_optimizer(self) -> type: + def _configure_optimizer(self) -> torch.optim.Optimizer: """Configures the optimizer.""" args = {} or self._optimizer.args if self._optimizer.type == "MADGRAD": optimizer_class = madgrad.MADGRAD else: optimizer_class = getattr(torch.optim, self._optimizer.type) - return optimizer_class(parameters=self.parameters(), **args) + return optimizer_class(params=self.parameters(), **args) def _configure_lr_scheduler(self) -> Dict[str, Any]: """Configures the lr scheduler.""" diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 3625ab2..983e274 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -19,16 +19,14 @@ class LitTransformerModel(LitBaseModel): def __init__( self, - network: Type[nn,Module], + network: Type[nn, Module], optimizer: Union[OmegaConf, Dict], lr_scheduler: Union[OmegaConf, Dict], criterion: Union[OmegaConf, Dict], monitor: str = "val_loss", mapping: Optional[List[str]] = None, ) -> None: - super().__init__( - network, optimizer, lr_scheduler, criterion, monitor - ) + super().__init__(network, optimizer, lr_scheduler, criterion, monitor) self.mapping, ignore_tokens = self.configure_mapping(mapping) self.val_cer = CharacterErrorRate(ignore_tokens) diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index aa024e0..85a84d2 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -1,9 +1,9 @@ """A Transformer with a cnn backbone. The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for +i.e. feature maps. A 2d positional encoding is applied to the feature maps for spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. +together with the target tokens. TODO: Local attention for transformer.j @@ -107,9 +107,7 @@ class ImageTransformer(nn.Module): encoder_class = getattr(network_module, encoder.type) return encoder_class(**encoder.args) - def _configure_mapping( - self, mapping: str - ) -> Tuple[List[str], Dict[str, int]]: + def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" if mapping == "emnist": mapping, inverse_mapping, _ = emnist_mapping() @@ -125,7 +123,7 @@ class ImageTransformer(nn.Module): Tensor: Image features. Shapes: - - image: :math: `(B, C, H, W)` + - image: :math: `(B, C, H, W)` - latent: :math: `(B, T, C)` """ diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index c33f419..da7553d 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + conv3x3 = partial( + Conv2dAuto, + kernel_size=3, + bias=False, + ) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index d7e3d08..b10f93a 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,7 +392,12 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 8847aba..67ed0d9 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,7 +44,12 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, ) def _build_decompression_block( @@ -73,7 +78,9 @@ class Decoder(nn.Module): ) if i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) + modules.append( + nn.Upsample(size=self.upsampling[i]), + ) if dropout is not None: modules.append(dropout) @@ -102,7 +109,12 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + nn.Conv2d( + self.embedding_dim, + channels[0], + kernel_size=1, + stride=1, + ) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index d3adac5..ede5c31 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], + self, + in_channels: int, + out_channels: int, + dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -135,7 +138,12 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + nn.Conv2d( + channels[-1], + self.embedding_dim, + kernel_size=1, + stride=1, + ) ) return nn.Sequential(*encoder) |