diff options
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 17 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/utils.py | 10 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 13 | ||||
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 4 |
4 files changed, 15 insertions, 29 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index de08457..2a712d8 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -61,14 +61,14 @@ class EfficientNet(nn.Module): """Builds the efficientnet backbone.""" _block_args = block_args()[: self.depth] in_channels = 1 # BW - out_channels = round_filters(32, self.params) + out_channels = round_filters(16, self.params) self._conv_stem = nn.Sequential( nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, - stride=self.stride, + stride=2, bias=False, ), nn.BatchNorm2d( @@ -98,19 +98,6 @@ class EfficientNet(nn.Module): nn.Conv2d( in_channels, self.out_channels, - kernel_size=2, - stride=self.stride, - bias=False, - ), - nn.BatchNorm2d( - num_features=self.out_channels, - momentum=self.bn_momentum, - eps=self.bn_eps, - ), - nn.Mish(inplace=True), - nn.Conv2d( - self.out_channels, - self.out_channels, kernel_size=1, stride=1, bias=False, diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py index 5234324..770f85b 100644 --- a/text_recognizer/networks/efficientnet/utils.py +++ b/text_recognizer/networks/efficientnet/utils.py @@ -72,13 +72,13 @@ def block_args() -> List[DictConfig]: "se_ratio", ] args = [ - [1, 3, (1, 1), 1, 32, 16, 0.25], + [1, 3, (1, 1), 1, 16, 16, 0.25], [2, 3, (2, 2), 6, 16, 24, 0.25], [2, 5, (2, 2), 6, 24, 40, 0.25], - [3, 3, (2, 2), 6, 40, 80, 0.25], - [3, 5, (1, 1), 6, 80, 112, 0.25], - [4, 5, (2, 2), 6, 112, 192, 0.25], - [1, 3, (1, 1), 6, 192, 320, 0.25], + [3, 3, (2, 1), 6, 40, 80, 0.25], + [3, 5, (2, 1), 6, 80, 112, 0.25], + [4, 5, (2, 1), 6, 112, 192, 0.25], + [1, 3, (2, 1), 6, 192, 320, 0.25], ] block_args_ = [] for row in args: diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 7c6e231..c8db485 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -34,15 +34,14 @@ optimizer: betas: [0.9, 0.999] weight_decay: 0 eps: 1.0e-8 - parameters: network lr_scheduler: _target_: torch.optim.lr_scheduler.OneCycleLR max_lr: 3.0e-4 total_steps: null epochs: *epochs - steps_per_epoch: 3201 - pct_start: 0.3 + steps_per_epoch: 5037 + pct_start: 0.15 anneal_strategy: cos cycle_momentum: true base_momentum: 0.85 @@ -56,7 +55,7 @@ lr_scheduler: monitor: val/cer datamodule: - batch_size: 6 + batch_size: 4 train_fraction: 0.95 network: @@ -66,9 +65,9 @@ network: encoder: depth: 5 decoder: - depth: 6 + depth: 4 pixel_embedding: - shape: [18, 78] + shape: [17, 79] model: max_output_len: *max_output_len @@ -76,4 +75,4 @@ model: trainer: gradient_clip_val: 0.5 max_epochs: *epochs - accumulate_grad_batches: 1 + accumulate_grad_batches: 2 diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index ccdf960..90c2cb8 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -9,7 +9,7 @@ encoder: stochastic_dropout_rate: 0.2 bn_momentum: 0.99 bn_eps: 1.0e-3 - depth: 3 + depth: 5 out_channels: *hidden_dim stride: [2, 1] decoder: @@ -47,4 +47,4 @@ decoder: pixel_embedding: _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding dim: *hidden_dim - shape: [17, 78] + shape: [18, 78] |