summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/base_data_module.py32
-rw-r--r--text_recognizer/data/base_dataset.py12
-rw-r--r--text_recognizer/data/emnist.py4
-rw-r--r--text_recognizer/data/emnist_lines.py18
-rw-r--r--text_recognizer/data/iam.py6
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py4
-rw-r--r--text_recognizer/data/iam_lines.py8
-rw-r--r--text_recognizer/data/iam_paragraphs.py13
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py4
-rw-r--r--text_recognizer/models/base.py24
-rw-r--r--text_recognizer/models/metrics.py10
-rw-r--r--text_recognizer/models/transformer.py24
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py32
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py71
-rw-r--r--text_recognizer/networks/transformer/attention.py24
15 files changed, 144 insertions, 142 deletions
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 15c286a..a0c8416 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-import attr
+from attrs import define, field
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -20,29 +20,29 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-@attr.s(repr=False)
+@define(repr=False)
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __attrs_pre_init__(self) -> None:
+ def __attrs_post_init__(self) -> None:
"""Pre init constructor."""
super().__init__()
- mapping: Type[AbstractMapping] = attr.ib()
- transform: Optional[Callable] = attr.ib(default=None)
- test_transform: Optional[Callable] = attr.ib(default=None)
- target_transform: Optional[Callable] = attr.ib(default=None)
- train_fraction: float = attr.ib(default=0.8)
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
- pin_memory: bool = attr.ib(default=True)
+ mapping: Type[AbstractMapping] = field()
+ transform: Optional[Callable] = field(default=None)
+ test_transform: Optional[Callable] = field(default=None)
+ target_transform: Optional[Callable] = field(default=None)
+ train_fraction: float = field(default=0.8)
+ batch_size: int = field(default=16)
+ num_workers: int = field(default=0)
+ pin_memory: bool = field(default=True)
# Placeholders
- data_train: BaseDataset = attr.ib(init=False, default=None)
- data_val: BaseDataset = attr.ib(init=False, default=None)
- data_test: BaseDataset = attr.ib(init=False, default=None)
- dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ data_train: BaseDataset = field(init=False, default=None)
+ data_val: BaseDataset = field(init=False, default=None)
+ data_test: BaseDataset = field(init=False, default=None)
+ dims: Tuple[int, ...] = field(init=False, default=None)
+ output_dims: Tuple[int, ...] = field(init=False, default=None)
@classmethod
def data_dirname(cls: T) -> Path:
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index b9567c7..c57cbcc 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,7 +1,7 @@
"""Base PyTorch Dataset class."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
-import attr
+from attrs import define, field
import torch
from torch import Tensor
from torch.utils.data import Dataset
@@ -9,7 +9,7 @@ from torch.utils.data import Dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@attr.s
+@define
class BaseDataset(Dataset):
r"""Base Dataset class that processes data and targets through optional transfroms.
@@ -21,10 +21,10 @@ class BaseDataset(Dataset):
target transforms.
"""
- data: Union[Sequence, Tensor] = attr.ib()
- targets: Union[Sequence, Tensor] = attr.ib()
- transform: Union[Optional[Callable], str] = attr.ib(default=None)
- target_transform: Union[Optional[Callable], str] = attr.ib(default=None)
+ data: Union[Sequence, Tensor] = field()
+ targets: Union[Sequence, Tensor] = field()
+ transform: Union[Optional[Callable], str] = field(default=None)
+ target_transform: Union[Optional[Callable], str] = field(default=None)
def __attrs_pre_init__(self) -> None:
"""Pre init constructor."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index dc8d31a..94882bf 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -6,7 +6,7 @@ import shutil
from typing import Dict, List, Optional, Sequence, Set, Tuple
import zipfile
-import attr
+from attrs import define
import h5py
from loguru import logger as log
import numpy as np
@@ -35,7 +35,7 @@ ESSENTIALS_FILENAME = (
)
-@attr.s(auto_attribs=True)
+@define(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index c267286..43d55b9 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,7 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import DefaultDict, List, Tuple
-import attr
+from attrs import define, field
import h5py
from loguru import logger as log
import numpy as np
@@ -33,17 +33,17 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST."""
- max_length: int = attr.ib(default=128)
- min_overlap: float = attr.ib(default=0.0)
- max_overlap: float = attr.ib(default=0.33)
- num_train: int = attr.ib(default=10_000)
- num_val: int = attr.ib(default=2_000)
- num_test: int = attr.ib(default=2_000)
- emnist: EMNIST = attr.ib(init=False, default=None)
+ max_length: int = field(default=128)
+ min_overlap: float = field(default=0.0)
+ max_overlap: float = field(default=0.33)
+ num_train: int = field(default=10_000)
+ num_val: int = field(default=2_000)
+ num_test: int = field(default=2_000)
+ emnist: EMNIST = field(init=False, default=None)
def __attrs_post_init__(self) -> None:
"""Post init constructor."""
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 766f3e0..8166863 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -9,7 +9,7 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
-import attr
+from attrs import define, field
from boltons.cacheutils import cachedproperty
from loguru import logger as log
import toml
@@ -27,7 +27,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
-@attr.s(auto_attribs=True)
+@define(auto_attribs=True)
class IAM(BaseDataModule):
r"""The IAM Lines dataset.
@@ -44,7 +44,7 @@ class IAM(BaseDataModule):
contributed to one set only.
"""
- metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))
+ metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME))
def prepare_data(self) -> None:
"""Prepares the IAM dataset."""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 22d00f1..52c10c3 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,5 +1,5 @@
"""IAM original and sythetic dataset class."""
-import attr
+from attrs import define, field
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -8,7 +8,7 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMExtendedParagraphs(BaseDataModule):
"""A dataset with synthetic and real handwritten paragraph."""
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index a79c202..34cf605 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,7 +7,7 @@ import json
from pathlib import Path
from typing import List, Sequence, Tuple
-import attr
+from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps
@@ -35,14 +35,14 @@ MAX_LABEL_LENGTH = 89
MAX_WORD_PIECE_LENGTH = 72
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- dims: Tuple[int, int, int] = attr.ib(
+ dims: Tuple[int, int, int] = field(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
- output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
+ output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 033b93e..b605bbc 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,7 +3,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
-import attr
+from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
@@ -33,15 +33,15 @@ MAX_LABEL_LENGTH = 682
MAX_WORD_PIECE_LENGTH = 451
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
# Placeholders
- dims: Tuple[int, int, int] = attr.ib(
+ dims: Tuple[int, int, int] = field(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
- output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
+ output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))
def prepare_data(self) -> None:
"""Create data for training/testing."""
@@ -86,7 +86,10 @@ class IAMParagraphs(BaseDataModule):
length=self.output_dims[0],
)
return BaseDataset(
- data, targets, transform=transform, target_transform=target_transform,
+ data,
+ targets,
+ transform=transform,
+ target_transform=target_transform,
)
log.info(f"Loading IAM paragraph regions and lines for {stage}...")
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index ea59098..7143951 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,7 +2,7 @@
import random
from typing import Any, List, Sequence, Tuple
-import attr
+from attrs import define
from loguru import logger as log
import numpy as np
from PIL import Image
@@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = (
)
-@attr.s(auto_attribs=True, repr=False)
+@define(auto_attribs=True, repr=False)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 821cb69..bf3bc08 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,7 +1,7 @@
"""Base PyTorch Lightning model."""
from typing import Any, Dict, List, Optional, Tuple, Type
-import attr
+from attrs import define, field
import hydra
from loguru import logger as log
from omegaconf import DictConfig
@@ -14,7 +14,7 @@ import torchmetrics
from text_recognizer.data.mappings.base import AbstractMapping
-@attr.s(eq=False)
+@define(eq=False)
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
@@ -22,22 +22,18 @@ class BaseLitModel(LightningModule):
"""Pre init constructor."""
super().__init__()
- network: Type[nn.Module] = attr.ib()
- loss_fn: Type[nn.Module] = attr.ib()
- optimizer_configs: DictConfig = attr.ib()
- lr_scheduler_configs: Optional[DictConfig] = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
+ network: Type[nn.Module] = field()
+ loss_fn: Type[nn.Module] = field()
+ optimizer_configs: DictConfig = field()
+ lr_scheduler_configs: Optional[DictConfig] = field()
+ mapping: Type[AbstractMapping] = field()
# Placeholders
- train_acc: torchmetrics.Accuracy = attr.ib(
- init=False, default=torchmetrics.Accuracy()
- )
- val_acc: torchmetrics.Accuracy = attr.ib(
- init=False, default=torchmetrics.Accuracy()
- )
- test_acc: torchmetrics.Accuracy = attr.ib(
+ train_acc: torchmetrics.Accuracy = field(
init=False, default=torchmetrics.Accuracy()
)
+ val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
+ test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
def optimizer_zero_grad(
self,
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index f83c9e4..e59a830 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,20 +1,20 @@
"""Character Error Rate (CER)."""
from typing import Set
-import attr
+from attrs import define, field
import editdistance
import torch
from torch import Tensor
from torchmetrics import Metric
-@attr.s(eq=False)
+@define(eq=False)
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_indices: Set[Tensor] = attr.ib(converter=set)
- error: Tensor = attr.ib(init=False)
- total: Tensor = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = field(converter=set)
+ error: Tensor = field(init=False)
+ total: Tensor = field(init=False)
def __attrs_post_init__(self) -> None:
super().__init__()
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 7272f46..c5120fe 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,7 +1,7 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Set, Tuple
-import attr
+from attrs import define, field
import torch
from torch import Tensor
@@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel
from text_recognizer.models.metrics import CharacterErrorRate
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- max_output_len: int = attr.ib(default=451)
- start_token: str = attr.ib(default="<s>")
- end_token: str = attr.ib(default="<e>")
- pad_token: str = attr.ib(default="<p>")
+ max_output_len: int = field(default=451)
+ start_token: str = field(default="<s>")
+ end_token: str = field(default="<e>")
+ pad_token: str = field(default="<p>")
- start_index: int = attr.ib(init=False)
- end_index: int = attr.ib(init=False)
- pad_index: int = attr.ib(init=False)
+ start_index: int = field(init=False)
+ end_index: int = field(init=False)
+ pad_index: int = field(init=False)
- ignore_indices: Set[Tensor] = attr.ib(init=False)
- val_cer: CharacterErrorRate = attr.ib(init=False)
- test_cer: CharacterErrorRate = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = field(init=False)
+ val_cer: CharacterErrorRate = field(init=False)
+ test_cer: CharacterErrorRate = field(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 4c9ed75..cf64bcf 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,7 @@
"""Efficientnet backbone."""
from typing import Tuple
-import attr
+from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@attr.s(eq=False)
+@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
@@ -33,28 +33,28 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = attr.ib()
- params: Tuple[float, float, float] = attr.ib(default=None, init=False)
- stochastic_dropout_rate: float = attr.ib(default=0.2)
- bn_momentum: float = attr.ib(default=0.99)
- bn_eps: float = attr.ib(default=1.0e-3)
- depth: int = attr.ib(default=7)
- out_channels: int = attr.ib(default=None, init=False)
- _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
- _blocks: nn.ModuleList = attr.ib(default=None, init=False)
- _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+ arch: str = field()
+ params: Tuple[float, float, float] = field(default=None, init=False)
+ stochastic_dropout_rate: float = field(default=0.2)
+ bn_momentum: float = field(default=0.99)
+ bn_eps: float = field(default=1.0e-3)
+ depth: int = field(default=7)
+ out_channels: int = field(default=None, init=False)
+ _conv_stem: nn.Sequential = field(default=None, init=False)
+ _blocks: nn.ModuleList = field(default=None, init=False)
+ _conv_head: nn.Sequential = field(default=None, init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self._build()
@depth.validator
- def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_depth(self, attribute, value: str) -> None:
if not 5 <= value <= 7:
raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
@arch.validator
- def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_arch(self, attribute, value: str) -> None:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
@@ -88,7 +88,9 @@ class EfficientNet(nn.Module):
for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
- **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
+ **args,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
)
)
args.in_channels = args.out_channels
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index beb7d57..98e9353 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -1,7 +1,7 @@
"""Mobile inverted residual block."""
from typing import Optional, Tuple, Union
-import attr
+from attrs import define, field
import torch
from torch import nn, Tensor
import torch.nn.functional as F
@@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
return (stride,) * 2 if isinstance(stride, int) else stride
-@attr.s(eq=False)
+@define(eq=False)
class BaseModule(nn.Module):
"""Base sub module class."""
- bn_momentum: float = attr.ib()
- bn_eps: float = attr.ib()
- block: nn.Sequential = attr.ib(init=False)
+ bn_momentum: float = field()
+ bn_eps: float = field()
+ block: nn.Sequential = field(init=False)
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -36,12 +36,12 @@ class BaseModule(nn.Module):
return self.block(x)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class InvertedBottleneck(BaseModule):
"""Inverted bottleneck module."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
+ in_channels: int = field()
+ out_channels: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class Depthwise(BaseModule):
"""Depthwise convolution module."""
- channels: int = attr.ib()
- kernel_size: int = attr.ib()
- stride: int = attr.ib()
+ channels: int = field()
+ kernel_size: int = field()
+ stride: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -85,13 +85,13 @@ class Depthwise(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class SqueezeAndExcite(BaseModule):
"""Sequeeze and excite module."""
- in_channels: int = attr.ib()
- channels: int = attr.ib()
- se_ratio: float = attr.ib()
+ in_channels: int = field()
+ channels: int = field()
+ se_ratio: float = field()
def _build(self) -> None:
num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class Pointwise(BaseModule):
"""Pointwise module."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
+ in_channels: int = field()
+ out_channels: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -133,32 +133,35 @@ class Pointwise(BaseModule):
)
-@attr.s(eq=False)
+@define(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
def __attrs_pre_init__(self) -> None:
super().__init__()
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
- kernel_size: Tuple[int, int] = attr.ib()
- stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
- bn_momentum: float = attr.ib()
- bn_eps: float = attr.ib()
- se_ratio: float = attr.ib()
- expand_ratio: int = attr.ib()
- pad: Tuple[int, int, int, int] = attr.ib(init=False)
- _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
- _depthwise: nn.Sequential = attr.ib(init=False)
- _squeeze_excite: nn.Sequential = attr.ib(init=False)
- _pointwise: nn.Sequential = attr.ib(init=False)
+ in_channels: int = field()
+ out_channels: int = field()
+ kernel_size: Tuple[int, int] = field()
+ stride: Tuple[int, int] = field(converter=_convert_stride)
+ bn_momentum: float = field()
+ bn_eps: float = field()
+ se_ratio: float = field()
+ expand_ratio: int = field()
+ pad: Tuple[int, int, int, int] = field(init=False)
+ _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False)
+ _depthwise: nn.Sequential = field(init=False)
+ _squeeze_excite: nn.Sequential = field(init=False)
+ _pointwise: nn.Sequential = field(init=False)
@pad.default
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+ return (
+ (self.kernel_size - 1) // 2 - 1,
+ (self.kernel_size - 1) // 2,
+ ) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 87792a9..aa15b88 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,7 +1,7 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
-import attr
+from attrs import define, field
from einops import rearrange
import torch
from torch import einsum
@@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import (
)
-@attr.s(eq=False)
+@define(eq=False)
class Attention(nn.Module):
"""Standard attention."""
def __attrs_pre_init__(self) -> None:
super().__init__()
- dim: int = attr.ib()
- num_heads: int = attr.ib()
- causal: bool = attr.ib(default=False)
- dim_head: int = attr.ib(default=64)
- dropout_rate: float = attr.ib(default=0.0)
- rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
- scale: float = attr.ib(init=False)
- dropout: nn.Dropout = attr.ib(init=False)
- fc: nn.Linear = attr.ib(init=False)
+ dim: int = field()
+ num_heads: int = field()
+ causal: bool = field(default=False)
+ dim_head: int = field(default=64)
+ dropout_rate: float = field(default=0.0)
+ rotary_embedding: Optional[RotaryEmbedding] = field(default=None)
+ scale: float = field(init=False)
+ dropout: nn.Dropout = field(init=False)
+ fc: nn.Linear = field(init=False)
def __attrs_post_init__(self) -> None:
self.scale = self.dim ** -0.5
@@ -120,7 +120,6 @@ def apply_input_mask(
input_mask = q_mask * k_mask
energy = energy.masked_fill_(~input_mask, mask_value)
- del input_mask
return energy
@@ -133,5 +132,4 @@ def apply_causal_mask(
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
energy.masked_fill_(mask, mask_value)
- del mask
return energy