From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- .../networks/encoders/efficientnet/efficientnet.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'text_recognizer/networks/encoders/efficientnet/efficientnet.py') diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index a36150a..b8eb53b 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -1,4 +1,4 @@ -"""Efficient net.""" +"""Efficientnet backbone.""" from typing import Tuple import attr @@ -12,8 +12,10 @@ from .utils import ( ) -@attr.s +@attr.s(eq=False) class EfficientNet(nn.Module): + """Efficientnet without classification head.""" + def __attrs_pre_init__(self) -> None: super().__init__() @@ -47,11 +49,13 @@ class EfficientNet(nn.Module): @arch.validator def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") self.params = self.archs[value] def _build(self) -> None: + """Builds the efficientnet backbone.""" _block_args = block_args() in_channels = 1 # BW out_channels = round_filters(32, self.params) @@ -73,8 +77,9 @@ class EfficientNet(nn.Module): for args in _block_args: args.in_channels = round_filters(args.in_channels, self.params) args.out_channels = round_filters(args.out_channels, self.params) - args.num_repeats = round_repeats(args.num_repeats, self.params) - for _ in range(args.num_repeats): + num_repeats = round_repeats(args.num_repeats, self.params) + del args.num_repeats + for _ in range(num_repeats): self._blocks.append( MBConvBlock( **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, @@ -93,6 +98,7 @@ class EfficientNet(nn.Module): ) def extract_features(self, x: Tensor) -> Tensor: + """Extracts the final feature map layer.""" x = self._conv_stem(x) for i, block in enumerate(self._blocks): stochastic_dropout_rate = self.stochastic_dropout_rate @@ -103,4 +109,5 @@ class EfficientNet(nn.Module): return x def forward(self, x: Tensor) -> Tensor: + """Returns efficientnet image features.""" return self.extract_features(x) -- cgit v1.2.3-70-g09d2