summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index a36150a..b8eb53b 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,4 @@
-"""Efficient net."""
+"""Efficientnet backbone."""
from typing import Tuple
import attr
@@ -12,8 +12,10 @@ from .utils import (
)
-@attr.s
+@attr.s(eq=False)
class EfficientNet(nn.Module):
+ """Efficientnet without classification head."""
+
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -47,11 +49,13 @@ class EfficientNet(nn.Module):
@arch.validator
def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ """Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
self.params = self.archs[value]
def _build(self) -> None:
+ """Builds the efficientnet backbone."""
_block_args = block_args()
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
@@ -73,8 +77,9 @@ class EfficientNet(nn.Module):
for args in _block_args:
args.in_channels = round_filters(args.in_channels, self.params)
args.out_channels = round_filters(args.out_channels, self.params)
- args.num_repeats = round_repeats(args.num_repeats, self.params)
- for _ in range(args.num_repeats):
+ num_repeats = round_repeats(args.num_repeats, self.params)
+ del args.num_repeats
+ for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
**args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
@@ -93,6 +98,7 @@ class EfficientNet(nn.Module):
)
def extract_features(self, x: Tensor) -> Tensor:
+ """Extracts the final feature map layer."""
x = self._conv_stem(x)
for i, block in enumerate(self._blocks):
stochastic_dropout_rate = self.stochastic_dropout_rate
@@ -103,4 +109,5 @@ class EfficientNet(nn.Module):
return x
def forward(self, x: Tensor) -> Tensor:
+ """Returns efficientnet image features."""
return self.extract_features(x)