summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-20 00:09:20 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-20 00:09:20 +0200
commit521f3bbbba9f04f48e81d78033c6e1c29a08e515 (patch)
tree1ffb50acde358fb151d114f63b760e67c77e3274
parentce3f63801013aba2f05cfb92f1a3a87393610d27 (diff)
Update eff net config
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py17
-rw-r--r--text_recognizer/networks/efficientnet/utils.py10
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml13
-rw-r--r--training/conf/network/conv_transformer.yaml4
4 files changed, 15 insertions, 29 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index de08457..2a712d8 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -61,14 +61,14 @@ class EfficientNet(nn.Module):
"""Builds the efficientnet backbone."""
_block_args = block_args()[: self.depth]
in_channels = 1 # BW
- out_channels = round_filters(32, self.params)
+ out_channels = round_filters(16, self.params)
self._conv_stem = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
- stride=self.stride,
+ stride=2,
bias=False,
),
nn.BatchNorm2d(
@@ -98,19 +98,6 @@ class EfficientNet(nn.Module):
nn.Conv2d(
in_channels,
self.out_channels,
- kernel_size=2,
- stride=self.stride,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=self.out_channels,
- momentum=self.bn_momentum,
- eps=self.bn_eps,
- ),
- nn.Mish(inplace=True),
- nn.Conv2d(
- self.out_channels,
- self.out_channels,
kernel_size=1,
stride=1,
bias=False,
diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
index 5234324..770f85b 100644
--- a/text_recognizer/networks/efficientnet/utils.py
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -72,13 +72,13 @@ def block_args() -> List[DictConfig]:
"se_ratio",
]
args = [
- [1, 3, (1, 1), 1, 32, 16, 0.25],
+ [1, 3, (1, 1), 1, 16, 16, 0.25],
[2, 3, (2, 2), 6, 16, 24, 0.25],
[2, 5, (2, 2), 6, 24, 40, 0.25],
- [3, 3, (2, 2), 6, 40, 80, 0.25],
- [3, 5, (1, 1), 6, 80, 112, 0.25],
- [4, 5, (2, 2), 6, 112, 192, 0.25],
- [1, 3, (1, 1), 6, 192, 320, 0.25],
+ [3, 3, (2, 1), 6, 40, 80, 0.25],
+ [3, 5, (2, 1), 6, 80, 112, 0.25],
+ [4, 5, (2, 1), 6, 112, 192, 0.25],
+ [1, 3, (2, 1), 6, 192, 320, 0.25],
]
block_args_ = []
for row in args:
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 7c6e231..c8db485 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -34,15 +34,14 @@ optimizer:
betas: [0.9, 0.999]
weight_decay: 0
eps: 1.0e-8
- parameters: network
lr_scheduler:
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 3.0e-4
total_steps: null
epochs: *epochs
- steps_per_epoch: 3201
- pct_start: 0.3
+ steps_per_epoch: 5037
+ pct_start: 0.15
anneal_strategy: cos
cycle_momentum: true
base_momentum: 0.85
@@ -56,7 +55,7 @@ lr_scheduler:
monitor: val/cer
datamodule:
- batch_size: 6
+ batch_size: 4
train_fraction: 0.95
network:
@@ -66,9 +65,9 @@ network:
encoder:
depth: 5
decoder:
- depth: 6
+ depth: 4
pixel_embedding:
- shape: [18, 78]
+ shape: [17, 79]
model:
max_output_len: *max_output_len
@@ -76,4 +75,4 @@ model:
trainer:
gradient_clip_val: 0.5
max_epochs: *epochs
- accumulate_grad_batches: 1
+ accumulate_grad_batches: 2
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index ccdf960..90c2cb8 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -9,7 +9,7 @@ encoder:
stochastic_dropout_rate: 0.2
bn_momentum: 0.99
bn_eps: 1.0e-3
- depth: 3
+ depth: 5
out_channels: *hidden_dim
stride: [2, 1]
decoder:
@@ -47,4 +47,4 @@ decoder:
pixel_embedding:
_target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
dim: *hidden_dim
- shape: [17, 78]
+ shape: [18, 78]