summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- text_recognizer/models/__init__.py | 2
-rw-r--r-- text_recognizer/networks/efficientnet/efficientnet.py | 11
-rw-r--r-- text_recognizer/networks/efficientnet/utils.py | 10
3 files changed, 7 insertions, 16 deletions
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index cc02487..56e3e93 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -1,2 +1,4 @@
"""PyTorch Lightning models modules."""
from text_recognizer.models.transformer import LitTransformer
+from text_recognizer.models.perceiver import LitPerceiver
+from text_recognizer.models.vq_transformer import LitVqTransformer
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index bd47e4b..3481090 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -73,17 +73,6 @@ class EfficientNet(nn.Module):
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
nn.Mish(inplace=True),
- nn.Conv2d(
- in_channels=out_channels,
- out_channels=out_channels,
- kernel_size=3,
- stride=2,
- bias=False,
- ),
- nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
- ),
- nn.Mish(inplace=True),
)
self._blocks = nn.ModuleList([])
for args in _block_args:
diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
index 412d07d..5234324 100644
--- a/text_recognizer/networks/efficientnet/utils.py
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -74,11 +74,11 @@ def block_args() -> List[DictConfig]:
args = [
[1, 3, (1, 1), 1, 32, 16, 0.25],
[2, 3, (2, 2), 6, 16, 24, 0.25],
- [2, 5, (2, 1), 6, 24, 40, 0.25],
- [3, 3, (2, 1), 6, 40, 80, 0.25],
- [3, 5, (2, 1), 6, 80, 112, 0.25],
- [4, 5, (1, 1), 6, 112, 192, 0.25],
- [1, 3, (2, 1), 6, 192, 320, 0.25],
+ [2, 5, (2, 2), 6, 24, 40, 0.25],
+ [3, 3, (2, 2), 6, 40, 80, 0.25],
+ [3, 5, (1, 1), 6, 80, 112, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
]
block_args_ = []
for row in args: