Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py    | 10
-rw-r--r--  text_recognizer/data/iam_paragraphs.py             |  9
-rw-r--r--  text_recognizer/models/base.py                     |  6
-rw-r--r--  text_recognizer/models/transformer.py              |  6
-rw-r--r--  text_recognizer/networks/image_transformer.py      | 10
-rw-r--r--  text_recognizer/networks/residual_network.py       |  6
-rw-r--r--  text_recognizer/networks/transducer/transducer.py  |  7
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py          | 18
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py          | 12
9 files changed, 60 insertions(+), 24 deletions(-)
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index d2529b4..c144341 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -18,10 +18,16 @@ class IAMExtendedParagraphs(BaseDataModule):
super().__init__(batch_size, num_workers)
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment,
+ batch_size,
+ num_workers,
+ train_fraction,
+ augment,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment,
+ batch_size,
+ num_workers,
+ train_fraction,
+ augment,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index f588587..314d458 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -161,7 +161,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -282,7 +285,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com
),
transforms.ColorJitter(brightness=(0.8, 1.6)),
transforms.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
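
For reference, the augmentation branch reformatted above composes torchvision transforms roughly as in the following sketch. Only ColorJitter and RandomAffine are quoted from the hunk; the Compose wrapper and the ToTensor step are assumed scaffolding, not code from this file.

from torchvision import transforms
from torchvision.transforms import InterpolationMode

# A minimal sketch of the augment branch of get_transform.
augment = transforms.Compose(
    [
        transforms.ColorJitter(brightness=(0.8, 1.6)),
        transforms.RandomAffine(
            degrees=1,
            shear=(-10, 10),
            interpolation=InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
    ]
)
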
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 1004f48..11d1eb1 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -15,7 +15,7 @@ class LitBaseModel(pl.LightningModule):
def __init__(
self,
- network: Type[nn,Module],
+ network: Type[nn.Module],
optimizer: Union[OmegaConf, Dict],
lr_scheduler: Union[OmegaConf, Dict],
criterion: Union[OmegaConf, Dict],
@@ -40,14 +40,14 @@ class LitBaseModel(pl.LightningModule):
args = criterion.args or {}
return getattr(nn, criterion.type)(**args)
- def _configure_optimizer(self) -> type:
+ def _configure_optimizer(self) -> torch.optim.Optimizer:
"""Configures the optimizer."""
args = self._optimizer.args or {}
if self._optimizer.type == "MADGRAD":
optimizer_class = madgrad.MADGRAD
else:
optimizer_class = getattr(torch.optim, self._optimizer.type)
- return optimizer_class(parameters=self.parameters(), **args)
+ return optimizer_class(params=self.parameters(), **args)
def _configure_lr_scheduler(self) -> Dict[str, Any]:
"""Configures the lr scheduler."""
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 3625ab2..983e274 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -19,16 +19,14 @@ class LitTransformerModel(LitBaseModel):
def __init__(
self,
- network: Type[nn,Module],
+            network: Type[nn.Module],
optimizer: Union[OmegaConf, Dict],
lr_scheduler: Union[OmegaConf, Dict],
criterion: Union[OmegaConf, Dict],
monitor: str = "val_loss",
mapping: Optional[List[str]] = None,
) -> None:
- super().__init__(
- network, optimizer, lr_scheduler, criterion, monitor
- )
+ super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index aa024e0..85a84d2 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -1,9 +1,9 @@
"""A Transformer with a cnn backbone.
The network encodes a image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
spatial information. The resulting feature are then set to a transformer decoder
-together with the target tokens.
+together with the target tokens.
TODO: Local attention for transformer.j
@@ -107,9 +107,7 @@ class ImageTransformer(nn.Module):
encoder_class = getattr(network_module, encoder.type)
return encoder_class(**encoder.args)
- def _configure_mapping(
- self, mapping: str
- ) -> Tuple[List[str], Dict[str, int]]:
+ def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
if mapping == "emnist":
mapping, inverse_mapping, _ = emnist_mapping()
@@ -125,7 +123,7 @@ class ImageTransformer(nn.Module):
Tensor: Image features.
Shapes:
- - image: :math: `(B, C, H, W)`
+            - image: :math:`(B, C, H, W)`
            - latent: :math:`(B, T, C)`
"""
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index c33f419..da7553d 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,7 +20,11 @@ class Conv2dAuto(nn.Conv2d):
def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
"""3x3 convolution with batch norm."""
- conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ conv3x3 = partial(
+ Conv2dAuto,
+ kernel_size=3,
+ bias=False,
+ )
return nn.Sequential(
conv3x3(in_channels, out_channels, *args, **kwargs),
nn.BatchNorm2d(out_channels),
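
For context on the conv3x3 partial reformatted above: functools.partial pins kernel_size and bias once, so each call only supplies the channel counts. A minimal sketch using plain nn.Conv2d with explicit padding as a stand-in for Conv2dAuto, which derives the padding itself:

from functools import partial
from torch import nn

# padding=1 reproduces what Conv2dAuto computes for a 3x3 kernel.
conv3x3 = partial(nn.Conv2d, kernel_size=3, bias=False, padding=1)
block = nn.Sequential(conv3x3(32, 64), nn.BatchNorm2d(64))
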
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index d7e3d08..b10f93a 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,7 +392,12 @@ def load_transducer_loss(
transitions = gtn.load(str(processed_path / transitions))
preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
)
num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..67ed0d9 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
)
def _build_decompression_block(
@@ -73,7 +78,9 @@ class Decoder(nn.Module):
)
if i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ modules.append(
+ nn.Upsample(size=self.upsampling[i]),
+ )
if dropout is not None:
modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ nn.Conv2d(
+ self.embedding_dim,
+ channels[0],
+ kernel_size=1,
+ stride=1,
+ )
)
# Bottleneck module.
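
On the nn.Upsample call above: passing size= resizes feature maps to an explicit (H, W) target rather than by a scale factor, which is how the decoder steps back up through self.upsampling. A sketch with placeholder sizes:

import torch
from torch import nn

up = nn.Upsample(size=(18, 20))   # target (H, W), not a scale factor
x = torch.randn(1, 64, 9, 10)
assert up(x).shape == (1, 64, 18, 20)
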
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..ede5c31 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ nn.Conv2d(
+ channels[-1],
+ self.embedding_dim,
+ kernel_size=1,
+ stride=1,
+ )
)
return nn.Sequential(*encoder)
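
The final encoder layer above projects the last feature channels to embedding_dim with a 1x1 convolution, preserving the spatial grid that the vector quantizer consumes. A sketch with placeholder channel sizes, where 128 -> 64 stands in for channels[-1] -> self.embedding_dim:

import torch
from torch import nn

project = nn.Conv2d(128, 64, kernel_size=1, stride=1)
x = torch.randn(2, 128, 9, 10)    # (B, channels[-1], H, W)
z = project(x)
assert z.shape == (2, 64, 9, 10)  # channels remapped, H and W unchanged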