Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/base_data_module.py              |  46
-rw-r--r--  text_recognizer/data/base_dataset.py                  |  19
-rw-r--r--  text_recognizer/data/emnist.py                        |   6
-rw-r--r--  text_recognizer/data/emnist_lines.py                  |  49
-rw-r--r--  text_recognizer/data/iam.py                           |   7
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py       |  27
-rw-r--r--  text_recognizer/data/iam_lines.py                     |  10
-rw-r--r--  text_recognizer/data/iam_paragraphs.py                |  11
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py      |   5
-rw-r--r--  text_recognizer/models/base.py                        |  36
-rw-r--r--  text_recognizer/models/metrics.py                     |  13
-rw-r--r--  text_recognizer/models/transformer.py                 |  29
-rw-r--r--  text_recognizer/networks/efficientnet/efficientnet.py |  47
-rw-r--r--  text_recognizer/networks/efficientnet/mbconv.py       | 121
14 files changed, 250 insertions(+), 176 deletions(-)
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index a0c8416..6306cf8 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,6 @@
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-from attrs import define, field
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -20,29 +19,36 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@define(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __attrs_post_init__(self) -> None:
- """Pre init constructor."""
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ ) -> None:
super().__init__()
-
- mapping: Type[AbstractMapping] = field()
- transform: Optional[Callable] = field(default=None)
- test_transform: Optional[Callable] = field(default=None)
- target_transform: Optional[Callable] = field(default=None)
- train_fraction: float = field(default=0.8)
- batch_size: int = field(default=16)
- num_workers: int = field(default=0)
- pin_memory: bool = field(default=True)
-
- # Placeholders
- data_train: BaseDataset = field(init=False, default=None)
- data_val: BaseDataset = field(init=False, default=None)
- data_test: BaseDataset = field(init=False, default=None)
- dims: Tuple[int, ...] = field(init=False, default=None)
- output_dims: Tuple[int, ...] = field(init=False, default=None)
+ self.mapping = mapping
+ self.transform = transform
+ self.test_transform = test_transform
+ self.target_transform = target_transform
+ self.train_fraction = train_fraction
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+
+ # Placeholders
+ self.data_train: BaseDataset
+ self.data_val: BaseDataset
+ self.data_test: BaseDataset
+ self.dims: Tuple[int, ...]
+ self.output_dims: Tuple[int, ...]
@classmethod
def data_dirname(cls: T) -> Path:
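
Note: the refactor trades attrs' generated __init__ and __attrs_post_init__ hooks for an explicit
constructor. A minimal sketch of how a subclass is expected to forward the shared arguments
(DummyMapping is a hypothetical stand-in for a concrete AbstractMapping implementation):

    from text_recognizer.data.base_data_module import BaseDataModule

    class DummyMapping:
        """Hypothetical mapping with a 28x28 input size."""
        input_size = (28, 28)

    class MyDataModule(BaseDataModule):
        def __init__(self, mapping, **kwargs) -> None:
            # Forward the shared arguments explicitly instead of relying on
            # attrs-generated wiring.
            super().__init__(mapping, **kwargs)
            self.dims = (1, *self.mapping.input_size)

    dm = MyDataModule(DummyMapping(), batch_size=32)
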
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index c57cbcc..675683a 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,7 +1,6 @@
"""Base PyTorch Dataset class."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
-from attrs import define, field
import torch
from torch import Tensor
from torch.utils.data import Dataset
@@ -9,7 +8,6 @@ from torch.utils.data import Dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@define
class BaseDataset(Dataset):
r"""Base Dataset class that processes data and targets through optional transfroms.
@@ -21,10 +19,19 @@ class BaseDataset(Dataset):
target transforms.
"""
- data: Union[Sequence, Tensor] = field()
- targets: Union[Sequence, Tensor] = field()
- transform: Union[Optional[Callable], str] = field(default=None)
- target_transform: Union[Optional[Callable], str] = field(default=None)
+ def __init__(
+ self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+        transform: Union[Optional[Callable], str] = None,
+        target_transform: Union[Optional[Callable], str] = None,
+ ) -> None:
+ super().__init__()
+
+ self.data = data
+ self.targets = targets
+ self.transform = transform
+ self.target_transform = target_transform
-    def __attrs_pre_init__(self) -> None:
-        """Pre init constructor."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 94882bf..1b1381a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -6,7 +6,6 @@ import shutil
from typing import Dict, List, Optional, Sequence, Set, Tuple
import zipfile
-from attrs import define
import h5py
from loguru import logger as log
import numpy as np
@@ -35,7 +34,6 @@ ESSENTIALS_FILENAME = (
)
-@define(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
@@ -48,8 +46,8 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 43d55b9..062257d 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,9 +1,8 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import DefaultDict, List, Tuple
+from typing import Callable, DefaultDict, List, Optional, Tuple, Type
-from attrs import define, field
import h5py
from loguru import logger as log
import numpy as np
@@ -17,6 +16,7 @@ from text_recognizer.data.base_data_module import (
)
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
@@ -33,22 +33,45 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@define(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- max_length: int = field(default=128)
- min_overlap: float = field(default=0.0)
- max_overlap: float = field(default=0.33)
- num_train: int = field(default=10_000)
- num_val: int = field(default=2_000)
- num_test: int = field(default=2_000)
- emnist: EMNIST = field(init=False, default=None)
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ max_length: int = 128,
+ min_overlap: float = 0.0,
+ max_overlap: float = 0.33,
+ num_train: int = 10_000,
+ num_val: int = 2_000,
+ num_test: int = 2_000,
+ ) -> None:
+ super().__init__(
+ mapping,
+ transform,
+ test_transform,
+ target_transform,
+ train_fraction,
+ batch_size,
+ num_workers,
+ pin_memory,
+ )
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
- self.emnist = EMNIST(mapping=self.mapping)
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_train = num_train
+ self.num_val = num_val
+ self.num_test = num_test
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ IMAGE_X_PADDING
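
Note: a worked example of the max_width computation above, using the defaults (max_length=128,
min_overlap=0.0) and the 28 px EMNIST glyph width (emnist.dims[2] == 28):

    IMAGE_X_PADDING = 28
    glyph_width = 28  # emnist.dims[2]
    max_width = int(glyph_width * (128 + 1) * (1 - 0.0)) + IMAGE_X_PADDING
    assert max_width == 3640  # 28 * 129 + 28
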
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 8166863..a4d1d21 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -9,7 +9,6 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
-from attrs import define, field
from boltons.cacheutils import cachedproperty
from loguru import logger as log
import toml
@@ -27,7 +26,6 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
-@define(auto_attribs=True)
class IAM(BaseDataModule):
r"""The IAM Lines dataset.
@@ -44,7 +42,10 @@ class IAM(BaseDataModule):
contributed to one set only.
"""
- metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME))
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+ self.metadata: Dict = toml.load(METADATA_FILENAME)
def prepare_data(self) -> None:
"""Prepares the IAM dataset."""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 52c10c3..0c16181 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,19 +1,38 @@
"""IAM original and sythetic dataset class."""
-from attrs import define, field
+from typing import Callable, Optional, Type
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@define(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
"""A dataset with synthetic and real handwritten paragraph."""
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
+ def __init__(
+ self,
+ mapping: Type[AbstractMapping],
+ transform: Optional[Callable] = None,
+ test_transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ train_fraction: float = 0.8,
+ batch_size: int = 16,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ ) -> None:
+ super().__init__(
+ mapping,
+ transform,
+ test_transform,
+ target_transform,
+ train_fraction,
+ batch_size,
+ num_workers,
+ pin_memory,
+ )
self.iam_paragraphs = IAMParagraphs(
mapping=self.mapping,
batch_size=self.batch_size,
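
Note: the setup step falls outside this hunk; a hedged sketch of how the two datamodules are
presumably combined, given the ConcatDataset import and the attributes above:

    from torch.utils.data import ConcatDataset

    def setup(self, stage: str = None) -> None:
        # Presumed pattern: train on real + synthetic, validate/test on real only.
        self.iam_paragraphs.setup(stage)
        self.iam_synthetic_paragraphs.setup(stage)
        self.data_train = ConcatDataset(
            [self.iam_paragraphs.data_train, self.iam_synthetic_paragraphs.data_train]
        )
        self.data_val = self.iam_paragraphs.data_val
        self.data_test = self.iam_paragraphs.data_test
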
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 34cf605..5f38f14 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,7 +7,6 @@ import json
from pathlib import Path
from typing import List, Sequence, Tuple
-from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps
@@ -35,14 +34,13 @@ MAX_LABEL_LENGTH = 89
MAX_WORD_PIECE_LENGTH = 72
-@define(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- dims: Tuple[int, int, int] = field(
- init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
- )
- output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+ self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.output_dims = (MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index b605bbc..dde505d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,7 +3,6 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
-from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
@@ -33,15 +32,13 @@ MAX_LABEL_LENGTH = 682
MAX_WORD_PIECE_LENGTH = 451
-@define(auto_attribs=True, repr=False)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- # Placeholders
- dims: Tuple[int, int, int] = field(
- init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
- )
- output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+ self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
+ self.output_dims = (MAX_LABEL_LENGTH, 1)
def prepare_data(self) -> None:
"""Create data for training/testing."""
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 7143951..2e7762c 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,7 +2,6 @@
import random
from typing import Any, List, Sequence, Tuple
-from attrs import define
from loguru import logger as log
import numpy as np
from PIL import Image
@@ -34,10 +33,12 @@ PROCESSED_DATA_DIRNAME = (
)
-@define(auto_attribs=True, repr=False)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
if PROCESSED_DATA_DIRNAME.exists():
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index bf3bc08..63fe5a7 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,7 +1,6 @@
"""Base PyTorch Lightning model."""
from typing import Any, Dict, List, Optional, Tuple, Type
-from attrs import define, field
import hydra
from loguru import logger as log
from omegaconf import DictConfig
@@ -9,31 +8,34 @@ from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
-import torchmetrics
+from torchmetrics import Accuracy
from text_recognizer.data.mappings.base import AbstractMapping
-@define(eq=False)
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
- def __attrs_pre_init__(self) -> None:
- """Pre init constructor."""
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ loss_fn: Type[nn.Module],
+ optimizer_configs: DictConfig,
+ lr_scheduler_configs: Optional[DictConfig],
+ mapping: Type[AbstractMapping],
+ ) -> None:
super().__init__()
- network: Type[nn.Module] = field()
- loss_fn: Type[nn.Module] = field()
- optimizer_configs: DictConfig = field()
- lr_scheduler_configs: Optional[DictConfig] = field()
- mapping: Type[AbstractMapping] = field()
-
- # Placeholders
- train_acc: torchmetrics.Accuracy = field(
- init=False, default=torchmetrics.Accuracy()
- )
- val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
- test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
+ self.network = network
+ self.loss_fn = loss_fn
+ self.optimizer_configs = optimizer_configs
+ self.lr_scheduler_configs = lr_scheduler_configs
+ self.mapping = mapping
+
+ # Placeholders
+ self.train_acc = Accuracy()
+ self.val_acc = Accuracy()
+ self.test_acc = Accuracy()
def optimizer_zero_grad(
self,
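
Note: optimizer_configs and lr_scheduler_configs are Hydra DictConfigs; the method that consumes
them is outside this hunk. A hedged sketch of the common instantiation pattern (assumes each
config entry uses hydra's _target_ convention):

    import hydra

    def configure_optimizers(self):
        # Bind each configured optimizer to the network parameters.
        return [
            hydra.utils.instantiate(config, params=self.network.parameters())
            for config in self.optimizer_configs.values()
        ]
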
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index e59a830..3cb16b5 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,25 +1,22 @@
"""Character Error Rate (CER)."""
-from typing import Set
+from typing import Sequence
-from attrs import define, field
import editdistance
import torch
from torch import Tensor
from torchmetrics import Metric
-@define(eq=False)
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_indices: Set[Tensor] = field(converter=set)
- error: Tensor = field(init=False)
- total: Tensor = field(init=False)
-
- def __attrs_post_init__(self) -> None:
+ def __init__(self, ignore_indices: Sequence[Tensor]) -> None:
super().__init__()
+ self.ignore_indices = set(ignore_indices)
self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+ self.error: Tensor
+ self.total: Tensor
def update(self, preds: Tensor, targets: Tensor) -> None:
"""Update CER."""
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index c5120fe..9537dd9 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,7 +1,6 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Set, Tuple
-from attrs import define, field
import torch
from torch import Tensor
@@ -9,25 +8,21 @@ from text_recognizer.models.base import BaseLitModel
from text_recognizer.models.metrics import CharacterErrorRate
-@define(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- max_output_len: int = field(default=451)
- start_token: str = field(default="<s>")
- end_token: str = field(default="<e>")
- pad_token: str = field(default="<p>")
-
- start_index: int = field(init=False)
- end_index: int = field(init=False)
- pad_index: int = field(init=False)
-
- ignore_indices: Set[Tensor] = field(init=False)
- val_cer: CharacterErrorRate = field(init=False)
- test_cer: CharacterErrorRate = field(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+    def __init__(
+        self,
+        max_output_len: int = 451,
+        start_token: str = "<s>",
+        end_token: str = "<e>",
+        pad_token: str = "<p>",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+ self.max_output_len = max_output_len
+ self.start_token = start_token
+ self.end_token = end_token
+ self.pad_token = pad_token
self.start_index = int(self.mapping.get_index(self.start_token))
self.end_index = int(self.mapping.get_index(self.end_token))
self.pad_index = int(self.mapping.get_index(self.pad_token))
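
Note: the constructor tail is truncated by this hunk; a sketch of its presumed continuation,
consistent with the ignore_indices/val_cer/test_cer fields removed above:

    # Presumed continuation of __init__:
    self.ignore_indices = {self.start_index, self.end_index, self.pad_index}
    self.val_cer = CharacterErrorRate(self.ignore_indices)
    self.test_cer = CharacterErrorRate(self.ignore_indices)
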
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index cf64bcf..2260ee2 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,6 @@
"""Efficientnet backbone."""
from typing import Tuple
-from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,13 +11,9 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
archs = {
# width, depth, dropout
"b0": (1.0, 1.0, 0.2),
@@ -33,32 +28,32 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = field()
- params: Tuple[float, float, float] = field(default=None, init=False)
- stochastic_dropout_rate: float = field(default=0.2)
- bn_momentum: float = field(default=0.99)
- bn_eps: float = field(default=1.0e-3)
- depth: int = field(default=7)
- out_channels: int = field(default=None, init=False)
- _conv_stem: nn.Sequential = field(default=None, init=False)
- _blocks: nn.ModuleList = field(default=None, init=False)
- _conv_head: nn.Sequential = field(default=None, init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+    def __init__(
+        self,
+        arch: str,
+        stochastic_dropout_rate: float = 0.2,
+        bn_momentum: float = 0.99,
+        bn_eps: float = 1.0e-3,
+        depth: int = 7,
+    ) -> None:
+        super().__init__()
+        if not 5 <= depth <= 7:
+            raise ValueError(f"Depth has to be between 5 and 7, was: {depth}")
+        self.params = self._get_arch_params(arch)
+        self.stochastic_dropout_rate = stochastic_dropout_rate
+        self.bn_momentum = bn_momentum
+        self.bn_eps = bn_eps
+        self.depth = depth
+ self.out_channels: int
+ self._conv_stem: nn.Sequential
+ self._blocks: nn.ModuleList
+ self._conv_head: nn.Sequential
self._build()
- @depth.validator
- def _check_depth(self, attribute, value: str) -> None:
- if not 5 <= value <= 7:
- raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
-
- @arch.validator
- def _check_arch(self, attribute, value: str) -> None:
+ def _get_arch_params(self, value: str) -> Tuple[float, float, float]:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
- self.params = self.archs[value]
+ return self.archs[value]
def _build(self) -> None:
"""Builds the efficientnet backbone."""
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index 98e9353..64debd9 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -1,7 +1,6 @@
"""Mobile inverted residual block."""
from typing import Optional, Tuple, Union
-from attrs import define, field
import torch
from torch import nn, Tensor
import torch.nn.functional as F
@@ -14,18 +13,15 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
return (stride,) * 2 if isinstance(stride, int) else stride
-@define(eq=False)
class BaseModule(nn.Module):
"""Base sub module class."""
- bn_momentum: float = field()
- bn_eps: float = field()
- block: nn.Sequential = field(init=False)
-
- def __attrs_pre_init__(self) -> None:
+    def __init__(self, bn_momentum: float, bn_eps: float) -> None:
        super().__init__()
-    def __attrs_post_init__(self) -> None:
+        self.bn_momentum = bn_momentum
+        self.bn_eps = bn_eps
+        self.block: nn.Sequential
-        self._build()
def _build(self) -> None:
@@ -36,12 +32,20 @@ class BaseModule(nn.Module):
return self.block(x)
-@define(auto_attribs=True, eq=False)
class InvertedBottleneck(BaseModule):
"""Inverted bottleneck module."""
- in_channels: int = field()
- out_channels: int = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+        in_channels: int,
+        out_channels: int,
+    ) -> None:
+        super().__init__(bn_momentum, bn_eps)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self._build()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -60,13 +64,22 @@ class InvertedBottleneck(BaseModule):
)
-@define(auto_attribs=True, eq=False)
class Depthwise(BaseModule):
"""Depthwise convolution module."""
- channels: int = field()
- kernel_size: int = field()
- stride: int = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+        channels: int,
+        kernel_size: int,
+        stride: int,
+    ) -> None:
+        super().__init__(bn_momentum, bn_eps)
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self._build()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -85,13 +98,23 @@ class Depthwise(BaseModule):
)
-@define(auto_attribs=True, eq=False)
class SqueezeAndExcite(BaseModule):
"""Sequeeze and excite module."""
- in_channels: int = field()
- channels: int = field()
- se_ratio: float = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+        in_channels: int,
+        channels: int,
+        se_ratio: float,
+    ) -> None:
+        super().__init__(bn_momentum, bn_eps)
+        self.in_channels = in_channels
+        self.channels = channels
+        self.se_ratio = se_ratio
+        self._build()
def _build(self) -> None:
num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -110,12 +133,20 @@ class SqueezeAndExcite(BaseModule):
)
-@define(auto_attribs=True, eq=False)
class Pointwise(BaseModule):
"""Pointwise module."""
- in_channels: int = field()
- out_channels: int = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+        in_channels: int,
+        out_channels: int,
+    ) -> None:
+        super().__init__(bn_momentum, bn_eps)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self._build()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -133,28 +164,36 @@ class Pointwise(BaseModule):
)
-@define(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[int, int],
+        stride: Union[Tuple[int, int], int],
+ bn_momentum: float,
+ bn_eps: float,
+ se_ratio: float,
+ expand_ratio: int,
+ ) -> None:
super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+        self.stride = _convert_stride(stride)
+ self.bn_momentum = bn_momentum
+ self.bn_eps = bn_eps
+ self.se_ratio = se_ratio
+ self.expand_ratio = expand_ratio
+ self.pad = self._configure_padding()
+ self._inverted_bottleneck: Optional[InvertedBottleneck]
+ self._depthwise: nn.Sequential
+ self._squeeze_excite: nn.Sequential
+ self._pointwise: nn.Sequential
+ self._build()
- in_channels: int = field()
- out_channels: int = field()
- kernel_size: Tuple[int, int] = field()
- stride: Tuple[int, int] = field(converter=_convert_stride)
- bn_momentum: float = field()
- bn_eps: float = field()
- se_ratio: float = field()
- expand_ratio: int = field()
- pad: Tuple[int, int, int, int] = field(init=False)
- _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False)
- _depthwise: nn.Sequential = field(init=False)
- _squeeze_excite: nn.Sequential = field(init=False)
- _pointwise: nn.Sequential = field(init=False)
-
- @pad.default
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
@@ -164,10 +203,6 @@ class MBConvBlock(nn.Module):
) * 2
return ((self.kernel_size - 1) // 2,) * 4
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- self._build()
-
def _build(self) -> None:
has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
inner_channels = self.in_channels * self.expand_ratio
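
Note: a small sketch of the stride normalization that the dropped attrs converter used to provide
and that __init__ now applies explicitly (mirrors the _convert_stride helper above):

    from typing import Tuple, Union

    def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
        """Convert an int stride to a symmetric (h, w) tuple."""
        return (stride,) * 2 if isinstance(stride, int) else stride

    assert _convert_stride(2) == (2, 2)
    assert _convert_stride((1, 2)) == (1, 2)
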