From 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 5 Jun 2022 23:39:11 +0200 Subject: Remove attrs --- text_recognizer/data/base_data_module.py | 46 ++++---- text_recognizer/data/base_dataset.py | 19 +++- text_recognizer/data/emnist.py | 6 +- text_recognizer/data/emnist_lines.py | 49 ++++++--- text_recognizer/data/iam.py | 7 +- text_recognizer/data/iam_extended_paragraphs.py | 27 ++++- text_recognizer/data/iam_lines.py | 10 +- text_recognizer/data/iam_paragraphs.py | 11 +- text_recognizer/data/iam_synthetic_paragraphs.py | 5 +- text_recognizer/models/base.py | 36 +++--- text_recognizer/models/metrics.py | 13 +-- text_recognizer/models/transformer.py | 29 ++--- .../networks/efficientnet/efficientnet.py | 47 ++++---- text_recognizer/networks/efficientnet/mbconv.py | 121 +++++++++++++-------- 14 files changed, 250 insertions(+), 176 deletions(-) diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index a0c8416..6306cf8 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -from attrs import define, field from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -@define(repr=False) class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __attrs_post_init__(self) -> None: - """Pre init constructor.""" + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + ) -> None: super().__init__() - - mapping: Type[AbstractMapping] = field() - transform: Optional[Callable] = field(default=None) - test_transform: Optional[Callable] = field(default=None) - target_transform: Optional[Callable] = field(default=None) - train_fraction: float = field(default=0.8) - batch_size: int = field(default=16) - num_workers: int = field(default=0) - pin_memory: bool = field(default=True) - - # Placeholders - data_train: BaseDataset = field(init=False, default=None) - data_val: BaseDataset = field(init=False, default=None) - data_test: BaseDataset = field(init=False, default=None) - dims: Tuple[int, ...] = field(init=False, default=None) - output_dims: Tuple[int, ...] = field(init=False, default=None) + self.mapping = mapping + self.transform = transform + self.test_transform = test_transform + self.target_transform = target_transform + self.train_fraction = train_fraction + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + + # Placeholders + self.data_train: BaseDataset + self.data_val: BaseDataset + self.data_test: BaseDataset + self.dims: Tuple[int, ...] + self.output_dims: Tuple[int, ...] @classmethod def data_dirname(cls: T) -> Path: diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c57cbcc..675683a 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,7 +1,6 @@ """Base PyTorch Dataset class.""" from typing import Callable, Dict, Optional, Sequence, Tuple, Union -from attrs import define, field import torch from torch import Tensor from torch.utils.data import Dataset @@ -9,7 +8,6 @@ from torch.utils.data import Dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file -@define class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -21,10 +19,19 @@ class BaseDataset(Dataset): target transforms. """ - data: Union[Sequence, Tensor] = field() - targets: Union[Sequence, Tensor] = field() - transform: Union[Optional[Callable], str] = field(default=None) - target_transform: Union[Optional[Callable], str] = field(default=None) + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Union[Optional[Callable], str], + target_transform: Union[Optional[Callable], str], + ) -> None: + super().__init__() + + self.data = data + self.targets = targets + self.transform = transform + self.target_transform = target_transform def __attrs_pre_init__(self) -> None: """Pre init constructor.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 94882bf..1b1381a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -6,7 +6,6 @@ import shutil from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile -from attrs import define import h5py from loguru import logger as log import numpy as np @@ -35,7 +34,6 @@ ESSENTIALS_FILENAME = ( ) -@define(auto_attribs=True) class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. @@ -48,8 +46,8 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__(self) -> None: + super().__init__() self.dims = (1, *self.mapping.input_size) def prepare_data(self) -> None: diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 43d55b9..062257d 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,9 +1,8 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import DefaultDict, List, Tuple +from typing import Callable, DefaultDict, List, Optional, Tuple, Type -from attrs import define, field import h5py from loguru import logger as log import numpy as np @@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import ( ) from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator @@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines -@define(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - max_length: int = field(default=128) - min_overlap: float = field(default=0.0) - max_overlap: float = field(default=0.33) - num_train: int = field(default=10_000) - num_val: int = field(default=2_000) - num_test: int = field(default=2_000) - emnist: EMNIST = field(init=False, default=None) + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + max_length: int = 128, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) - def __attrs_post_init__(self) -> None: - """Post init constructor.""" - self.emnist = EMNIST(mapping=self.mapping) + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + IMAGE_X_PADDING diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 8166863..a4d1d21 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile -from attrs import define, field from boltons.cacheutils import cachedproperty from loguru import logger as log import toml @@ -27,7 +26,6 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. -@define(auto_attribs=True) class IAM(BaseDataModule): r"""The IAM Lines dataset. @@ -44,7 +42,10 @@ class IAM(BaseDataModule): contributed to one set only. """ - metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME)) + def __init__(self) -> None: + super().__init__() + + self.metadata: Dict = toml.load(METADATA_FILENAME) def prepare_data(self) -> None: """Prepares the IAM dataset.""" diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 52c10c3..0c16181 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,19 +1,38 @@ """IAM original and sythetic dataset class.""" -from attrs import define, field +from typing import Callable, Optional, Type from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.load_transform import load_transform_from_file -@define(auto_attribs=True, repr=False) class IAMExtendedParagraphs(BaseDataModule): """A dataset with synthetic and real handwritten paragraph.""" - def __attrs_post_init__(self) -> None: - """Post init constructor.""" + def __init__( + self, + mapping: Type[AbstractMapping], + transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + train_fraction: float = 0.8, + batch_size: int = 16, + num_workers: int = 0, + pin_memory: bool = True, + ) -> None: + super().__init__( + mapping, + transform, + test_transform, + target_transform, + train_fraction, + batch_size, + num_workers, + pin_memory, + ) self.iam_paragraphs = IAMParagraphs( mapping=self.mapping, batch_size=self.batch_size, diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 34cf605..5f38f14 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,7 +7,6 @@ import json from pathlib import Path from typing import List, Sequence, Tuple -from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageFile, ImageOps @@ -35,14 +34,13 @@ MAX_LABEL_LENGTH = 89 MAX_WORD_PIECE_LENGTH = 72 -@define(auto_attribs=True, repr=False) class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - dims: Tuple[int, int, int] = field( - init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) - ) - output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) + def __init__(self) -> None: + super().__init__() + self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) + self.output_dims = (MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index b605bbc..dde505d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,7 +3,6 @@ import json from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple -from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageOps @@ -33,15 +32,13 @@ MAX_LABEL_LENGTH = 682 MAX_WORD_PIECE_LENGTH = 451 -@define(auto_attribs=True, repr=False) class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - # Placeholders - dims: Tuple[int, int, int] = field( - init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) - ) - output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) + def __init__(self) -> None: + super().__init__() + self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) + self.output_dims = (MAX_LABEL_LENGTH, 1) def prepare_data(self) -> None: """Create data for training/testing.""" diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 7143951..2e7762c 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,7 +2,6 @@ import random from typing import Any, List, Sequence, Tuple -from attrs import define from loguru import logger as log import numpy as np from PIL import Image @@ -34,10 +33,12 @@ PROCESSED_DATA_DIRNAME = ( ) -@define(auto_attribs=True, repr=False) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" + def __init__(self) -> None: + super().__init__() + def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bf3bc08..63fe5a7 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,6 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -9,31 +8,34 @@ from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor -import torchmetrics +from torchmetrics import Accuracy from text_recognizer.data.mappings.base import AbstractMapping -@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - def __attrs_pre_init__(self) -> None: - """Pre init constructor.""" + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_configs: DictConfig, + lr_scheduler_configs: Optional[DictConfig], + mapping: Type[AbstractMapping], + ) -> None: super().__init__() - network: Type[nn.Module] = field() - loss_fn: Type[nn.Module] = field() - optimizer_configs: DictConfig = field() - lr_scheduler_configs: Optional[DictConfig] = field() - mapping: Type[AbstractMapping] = field() - - # Placeholders - train_acc: torchmetrics.Accuracy = field( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) - test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + self.network = network + self.loss_fn = loss_fn + self.optimizer_configs = optimizer_configs + self.lr_scheduler_configs = lr_scheduler_configs + self.mapping = mapping + + # Placeholders + self.train_acc = Accuracy() + self.val_acc = Accuracy() + self.test_acc = Accuracy() def optimizer_zero_grad( self, diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index e59a830..3cb16b5 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,25 +1,22 @@ """Character Error Rate (CER).""" -from typing import Set +from typing import Sequence -from attrs import define, field import editdistance import torch from torch import Tensor from torchmetrics import Metric -@define(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set[Tensor] = field(converter=set) - error: Tensor = field(init=False) - total: Tensor = field(init=False) - - def __attrs_post_init__(self) -> None: + def __init__(self, ignore_indices: Sequence[Tensor]) -> None: super().__init__() + self.ignore_indices = set(ignore_indices) self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + self.error: Tensor + self.total: Tensor def update(self, preds: Tensor, targets: Tensor) -> None: """Update CER.""" diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index c5120fe..9537dd9 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,6 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -from attrs import define, field import torch from torch import Tensor @@ -9,25 +8,21 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = field(default=451) - start_token: str = field(default="") - end_token: str = field(default="") - pad_token: str = field(default="

") - - start_index: int = field(init=False) - end_index: int = field(init=False) - pad_index: int = field(init=False) - - ignore_indices: Set[Tensor] = field(init=False) - val_cer: CharacterErrorRate = field(init=False) - test_cer: CharacterErrorRate = field(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__( + self, + max_output_len: int = 451, + start_token: str = "", + end_token: str = "", + pad_token: str = "

", + ) -> None: + super().__init__() + self.max_output_len = max_output_len + self.start_token = start_token + self.end_token = end_token + self.pad_token = pad_token self.start_index = int(self.mapping.get_index(self.start_token)) self.end_index = int(self.mapping.get_index(self.end_token)) self.pad_index = int(self.mapping.get_index(self.pad_token)) diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index cf64bcf..2260ee2 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,6 @@ """Efficientnet backbone.""" from typing import Tuple -from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,13 +11,9 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" - def __attrs_pre_init__(self) -> None: - super().__init__() - archs = { # width, depth, dropout "b0": (1.0, 1.0, 0.2), @@ -33,32 +28,32 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = field() - params: Tuple[float, float, float] = field(default=None, init=False) - stochastic_dropout_rate: float = field(default=0.2) - bn_momentum: float = field(default=0.99) - bn_eps: float = field(default=1.0e-3) - depth: int = field(default=7) - out_channels: int = field(default=None, init=False) - _conv_stem: nn.Sequential = field(default=None, init=False) - _blocks: nn.ModuleList = field(default=None, init=False) - _conv_head: nn.Sequential = field(default=None, init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__( + self, + arch: str, + params: Tuple[float, float, float], + stochastic_dropout_rate: float = 0.2, + bn_momentum: float = 0.99, + bn_eps: float = 1.0e-3, + depth: int = 7, + ) -> None: + super().__init__() + self.params = self._get_arch_params(arch) + self.stochastic_dropout_rate = stochastic_dropout_rate + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.depth = depth + self.out_channels: int + self._conv_stem: nn.Sequential + self._blocks: nn.ModuleList + self._conv_head: nn.Sequential self._build() - @depth.validator - def _check_depth(self, attribute, value: str) -> None: - if not 5 <= value <= 7: - raise ValueError(f"Depth has to be between 5 and 7, was: {value}") - - @arch.validator - def _check_arch(self, attribute, value: str) -> None: + def _get_arch_params(self, value: str) -> Tuple[float, float, float]: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") - self.params = self.archs[value] + return self.archs[value] def _build(self) -> None: """Builds the efficientnet backbone.""" diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index 98e9353..64debd9 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,6 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,18 +13,15 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = field() - bn_eps: float = field() - block: nn.Sequential = field(init=False) - - def __attrs_pre_init__(self) -> None: + def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None: super().__init__() - def __attrs_post_init__(self) -> None: + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.block = block self._build() def _build(self) -> None: @@ -36,12 +32,20 @@ class BaseModule(nn.Module): return self.block(x) -@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = field() - out_channels: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.in_channels = in_channels + self.out_channels = out_channels def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +64,22 @@ class InvertedBottleneck(BaseModule): ) -@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = field() - kernel_size: int = field() - stride: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + channels: int, + kernel_size: int, + stride: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.channels = channels + self.kernel_size = kernel_size + self.stride = stride def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +98,23 @@ class Depthwise(BaseModule): ) -@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = field() - channels: int = field() - se_ratio: float = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + channels: int, + se_ratio: float, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + + self.in_channels = in_channels + self.channels = channels + self.se_ratio = se_ratio def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +133,20 @@ class SqueezeAndExcite(BaseModule): ) -@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = field() - out_channels: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.in_channels = in_channels + self.out_channels = out_channels def _build(self) -> None: self.block = nn.Sequential( @@ -133,28 +164,36 @@ class Pointwise(BaseModule): ) -@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __attrs_pre_init__(self) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + bn_momentum: float, + bn_eps: float, + se_ratio: float, + expand_ratio: int, + ) -> None: super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.se_ratio = se_ratio + self.expand_ratio = expand_ratio + self.pad = self._configure_padding() + self._inverted_bottleneck: Optional[InvertedBottleneck] + self._depthwise: nn.Sequential + self._squeeze_excite: nn.Sequential + self._pointwise: nn.Sequential + self._build() - in_channels: int = field() - out_channels: int = field() - kernel_size: Tuple[int, int] = field() - stride: Tuple[int, int] = field(converter=_convert_stride) - bn_momentum: float = field() - bn_eps: float = field() - se_ratio: float = field() - expand_ratio: int = field() - pad: Tuple[int, int, int, int] = field(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) - _depthwise: nn.Sequential = field(init=False) - _squeeze_excite: nn.Sequential = field(init=False) - _pointwise: nn.Sequential = field(init=False) - - @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): @@ -164,10 +203,6 @@ class MBConvBlock(nn.Module): ) * 2 return ((self.kernel_size - 1) // 2,) * 4 - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self._build() - def _build(self) -> None: has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 inner_channels = self.in_channels * self.expand_ratio -- cgit v1.2.3-70-g09d2