diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:06:37 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:06:37 +0200 |
commit | cfb460666953c87f606833bf597b53eba0a2900d (patch) | |
tree | ece2fcb5c9003a4dac8ed4f2108d063b07cdc274 | |
parent | f95d51e45ea24a956ce4384e4680f849651b2506 (diff) |
Format
-rw-r--r-- | text_recognizer/models/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 11 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/utils.py | 10 |
3 files changed, 7 insertions, 16 deletions
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index cc02487..56e3e93 100644 --- a/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py @@ -1,2 +1,4 @@ """PyTorch Lightning models modules.""" from text_recognizer.models.transformer import LitTransformer +from text_recognizer.models.perceiver import LitPerceiver +from text_recognizer.models.vq_transformer import LitVqTransformer diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index bd47e4b..3481090 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -73,17 +73,6 @@ class EfficientNet(nn.Module): num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps ), nn.Mish(inplace=True), - nn.Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - stride=2, - bias=False, - ), - nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps - ), - nn.Mish(inplace=True), ) self._blocks = nn.ModuleList([]) for args in _block_args: diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py index 412d07d..5234324 100644 --- a/text_recognizer/networks/efficientnet/utils.py +++ b/text_recognizer/networks/efficientnet/utils.py @@ -74,11 +74,11 @@ def block_args() -> List[DictConfig]: args = [ [1, 3, (1, 1), 1, 32, 16, 0.25], [2, 3, (2, 2), 6, 16, 24, 0.25], - [2, 5, (2, 1), 6, 24, 40, 0.25], - [3, 3, (2, 1), 6, 40, 80, 0.25], - [3, 5, (2, 1), 6, 80, 112, 0.25], - [4, 5, (1, 1), 6, 112, 192, 0.25], - [1, 3, (2, 1), 6, 192, 320, 0.25], + [2, 5, (2, 2), 6, 24, 40, 0.25], + [3, 3, (2, 2), 6, 40, 80, 0.25], + [3, 5, (1, 1), 6, 80, 112, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], ] block_args_ = [] for row in args: |