From db86cef2d308f58325278061c6aa177a535e7e03 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 1 Jun 2022 23:10:12 +0200 Subject: Replace attr with attrs --- text_recognizer/data/base_data_module.py | 32 +++++----- text_recognizer/data/base_dataset.py | 12 ++-- text_recognizer/data/emnist.py | 4 +- text_recognizer/data/emnist_lines.py | 18 +++--- text_recognizer/data/iam.py | 6 +- text_recognizer/data/iam_extended_paragraphs.py | 4 +- text_recognizer/data/iam_lines.py | 8 +-- text_recognizer/data/iam_paragraphs.py | 13 ++-- text_recognizer/data/iam_synthetic_paragraphs.py | 4 +- text_recognizer/models/base.py | 24 +++----- text_recognizer/models/metrics.py | 10 +-- text_recognizer/models/transformer.py | 24 ++++---- .../networks/efficientnet/efficientnet.py | 32 +++++----- text_recognizer/networks/efficientnet/mbconv.py | 71 +++++++++++----------- text_recognizer/networks/transformer/attention.py | 24 ++++---- 15 files changed, 144 insertions(+), 142 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 15c286a..a0c8416 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -import attr +from attrs import define, field from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -20,29 +20,29 @@ def load_and_print_info(data_module_class: type) -> None: print(dataset) -@attr.s(repr=False) +@define(repr=False) class BaseDataModule(LightningDataModule): """Base PyTorch Lightning DataModule.""" - def __attrs_pre_init__(self) -> None: + def __attrs_post_init__(self) -> None: """Pre init constructor.""" super().__init__() - mapping: Type[AbstractMapping] = attr.ib() - transform: Optional[Callable] = attr.ib(default=None) - test_transform: Optional[Callable] = attr.ib(default=None) - target_transform: Optional[Callable] = attr.ib(default=None) - train_fraction: float = attr.ib(default=0.8) - batch_size: int = attr.ib(default=16) - num_workers: int = attr.ib(default=0) - pin_memory: bool = attr.ib(default=True) + mapping: Type[AbstractMapping] = field() + transform: Optional[Callable] = field(default=None) + test_transform: Optional[Callable] = field(default=None) + target_transform: Optional[Callable] = field(default=None) + train_fraction: float = field(default=0.8) + batch_size: int = field(default=16) + num_workers: int = field(default=0) + pin_memory: bool = field(default=True) # Placeholders - data_train: BaseDataset = attr.ib(init=False, default=None) - data_val: BaseDataset = attr.ib(init=False, default=None) - data_test: BaseDataset = attr.ib(init=False, default=None) - dims: Tuple[int, ...] = attr.ib(init=False, default=None) - output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) + data_train: BaseDataset = field(init=False, default=None) + data_val: BaseDataset = field(init=False, default=None) + data_test: BaseDataset = field(init=False, default=None) + dims: Tuple[int, ...] = field(init=False, default=None) + output_dims: Tuple[int, ...] = field(init=False, default=None) @classmethod def data_dirname(cls: T) -> Path: diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index b9567c7..c57cbcc 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,7 +1,7 @@ """Base PyTorch Dataset class.""" from typing import Callable, Dict, Optional, Sequence, Tuple, Union -import attr +from attrs import define, field import torch from torch import Tensor from torch.utils.data import Dataset @@ -9,7 +9,7 @@ from torch.utils.data import Dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file -@attr.s +@define class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -21,10 +21,10 @@ class BaseDataset(Dataset): target transforms. """ - data: Union[Sequence, Tensor] = attr.ib() - targets: Union[Sequence, Tensor] = attr.ib() - transform: Union[Optional[Callable], str] = attr.ib(default=None) - target_transform: Union[Optional[Callable], str] = attr.ib(default=None) + data: Union[Sequence, Tensor] = field() + targets: Union[Sequence, Tensor] = field() + transform: Union[Optional[Callable], str] = field(default=None) + target_transform: Union[Optional[Callable], str] = field(default=None) def __attrs_pre_init__(self) -> None: """Pre init constructor.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index dc8d31a..94882bf 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -6,7 +6,7 @@ import shutil from typing import Dict, List, Optional, Sequence, Set, Tuple import zipfile -import attr +from attrs import define import h5py from loguru import logger as log import numpy as np @@ -35,7 +35,7 @@ ESSENTIALS_FILENAME = ( ) -@attr.s(auto_attribs=True) +@define(auto_attribs=True) class EMNIST(BaseDataModule): """Lightning DataModule class for loading EMNIST dataset. diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index c267286..43d55b9 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -3,7 +3,7 @@ from collections import defaultdict from pathlib import Path from typing import DefaultDict, List, Tuple -import attr +from attrs import define, field import h5py from loguru import logger as log import numpy as np @@ -33,17 +33,17 @@ IMAGE_X_PADDING = 28 MAX_OUTPUT_LENGTH = 89 # Same as IAMLines -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST.""" - max_length: int = attr.ib(default=128) - min_overlap: float = attr.ib(default=0.0) - max_overlap: float = attr.ib(default=0.33) - num_train: int = attr.ib(default=10_000) - num_val: int = attr.ib(default=2_000) - num_test: int = attr.ib(default=2_000) - emnist: EMNIST = attr.ib(init=False, default=None) + max_length: int = field(default=128) + min_overlap: float = field(default=0.0) + max_overlap: float = field(default=0.33) + num_train: int = field(default=10_000) + num_val: int = field(default=2_000) + num_test: int = field(default=2_000) + emnist: EMNIST = field(init=False, default=None) def __attrs_post_init__(self) -> None: """Post init constructor.""" diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 766f3e0..8166863 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile -import attr +from attrs import define, field from boltons.cacheutils import cachedproperty from loguru import logger as log import toml @@ -27,7 +27,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. -@attr.s(auto_attribs=True) +@define(auto_attribs=True) class IAM(BaseDataModule): r"""The IAM Lines dataset. @@ -44,7 +44,7 @@ class IAM(BaseDataModule): contributed to one set only. """ - metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME)) + metadata: Dict = field(init=False, default=toml.load(METADATA_FILENAME)) def prepare_data(self) -> None: """Prepares the IAM dataset.""" diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 22d00f1..52c10c3 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,5 +1,5 @@ """IAM original and sythetic dataset class.""" -import attr +from attrs import define, field from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -8,7 +8,7 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs from text_recognizer.data.transforms.load_transform import load_transform_from_file -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMExtendedParagraphs(BaseDataModule): """A dataset with synthetic and real handwritten paragraph.""" diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index a79c202..34cf605 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,7 +7,7 @@ import json from pathlib import Path from typing import List, Sequence, Tuple -import attr +from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageFile, ImageOps @@ -35,14 +35,14 @@ MAX_LABEL_LENGTH = 89 MAX_WORD_PIECE_LENGTH = 72 -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - dims: Tuple[int, int, int] = attr.ib( + dims: Tuple[int, int, int] = field( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) - output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) + output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 033b93e..b605bbc 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,7 +3,7 @@ import json from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple -import attr +from attrs import define, field from loguru import logger as log import numpy as np from PIL import Image, ImageOps @@ -33,15 +33,15 @@ MAX_LABEL_LENGTH = 682 MAX_WORD_PIECE_LENGTH = 451 -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" # Placeholders - dims: Tuple[int, int, int] = attr.ib( + dims: Tuple[int, int, int] = field( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) - output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) + output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1)) def prepare_data(self) -> None: """Create data for training/testing.""" @@ -86,7 +86,10 @@ class IAMParagraphs(BaseDataModule): length=self.output_dims[0], ) return BaseDataset( - data, targets, transform=transform, target_transform=target_transform, + data, + targets, + transform=transform, + target_transform=target_transform, ) log.info(f"Loading IAM paragraph regions and lines for {stage}...") diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index ea59098..7143951 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,7 +2,7 @@ import random from typing import Any, List, Sequence, Tuple -import attr +from attrs import define from loguru import logger as log import numpy as np from PIL import Image @@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = ( ) -@attr.s(auto_attribs=True, repr=False) +@define(auto_attribs=True, repr=False) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 821cb69..bf3bc08 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,7 +1,7 @@ """Base PyTorch Lightning model.""" from typing import Any, Dict, List, Optional, Tuple, Type -import attr +from attrs import define, field import hydra from loguru import logger as log from omegaconf import DictConfig @@ -14,7 +14,7 @@ import torchmetrics from text_recognizer.data.mappings.base import AbstractMapping -@attr.s(eq=False) +@define(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" @@ -22,22 +22,18 @@ class BaseLitModel(LightningModule): """Pre init constructor.""" super().__init__() - network: Type[nn.Module] = attr.ib() - loss_fn: Type[nn.Module] = attr.ib() - optimizer_configs: DictConfig = attr.ib() - lr_scheduler_configs: Optional[DictConfig] = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + network: Type[nn.Module] = field() + loss_fn: Type[nn.Module] = field() + optimizer_configs: DictConfig = field() + lr_scheduler_configs: Optional[DictConfig] = field() + mapping: Type[AbstractMapping] = field() # Placeholders - train_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - val_acc: torchmetrics.Accuracy = attr.ib( - init=False, default=torchmetrics.Accuracy() - ) - test_acc: torchmetrics.Accuracy = attr.ib( + train_acc: torchmetrics.Accuracy = field( init=False, default=torchmetrics.Accuracy() ) + val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy()) def optimizer_zero_grad( self, diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index f83c9e4..e59a830 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,20 +1,20 @@ """Character Error Rate (CER).""" from typing import Set -import attr +from attrs import define, field import editdistance import torch from torch import Tensor from torchmetrics import Metric -@attr.s(eq=False) +@define(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set[Tensor] = attr.ib(converter=set) - error: Tensor = attr.ib(init=False) - total: Tensor = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(converter=set) + error: Tensor = field(init=False) + total: Tensor = field(init=False) def __attrs_post_init__(self) -> None: super().__init__() diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7272f46..c5120fe 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,7 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -import attr +from attrs import define, field import torch from torch import Tensor @@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = attr.ib(default=451) - start_token: str = attr.ib(default="") - end_token: str = attr.ib(default="") - pad_token: str = attr.ib(default="

") + max_output_len: int = field(default=451) + start_token: str = field(default="") + end_token: str = field(default="") + pad_token: str = field(default="

") - start_index: int = attr.ib(init=False) - end_index: int = attr.ib(init=False) - pad_index: int = attr.ib(init=False) + start_index: int = field(init=False) + end_index: int = field(init=False) + pad_index: int = field(init=False) - ignore_indices: Set[Tensor] = attr.ib(init=False) - val_cer: CharacterErrorRate = attr.ib(init=False) - test_cer: CharacterErrorRate = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(init=False) + val_cer: CharacterErrorRate = field(init=False) + test_cer: CharacterErrorRate = field(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 4c9ed75..cf64bcf 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,7 @@ """Efficientnet backbone.""" from typing import Tuple -import attr +from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@attr.s(eq=False) +@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" @@ -33,28 +33,28 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = attr.ib() - params: Tuple[float, float, float] = attr.ib(default=None, init=False) - stochastic_dropout_rate: float = attr.ib(default=0.2) - bn_momentum: float = attr.ib(default=0.99) - bn_eps: float = attr.ib(default=1.0e-3) - depth: int = attr.ib(default=7) - out_channels: int = attr.ib(default=None, init=False) - _conv_stem: nn.Sequential = attr.ib(default=None, init=False) - _blocks: nn.ModuleList = attr.ib(default=None, init=False) - _conv_head: nn.Sequential = attr.ib(default=None, init=False) + arch: str = field() + params: Tuple[float, float, float] = field(default=None, init=False) + stochastic_dropout_rate: float = field(default=0.2) + bn_momentum: float = field(default=0.99) + bn_eps: float = field(default=1.0e-3) + depth: int = field(default=7) + out_channels: int = field(default=None, init=False) + _conv_stem: nn.Sequential = field(default=None, init=False) + _blocks: nn.ModuleList = field(default=None, init=False) + _conv_head: nn.Sequential = field(default=None, init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" self._build() @depth.validator - def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_depth(self, attribute, value: str) -> None: if not 5 <= value <= 7: raise ValueError(f"Depth has to be between 5 and 7, was: {value}") @arch.validator - def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_arch(self, attribute, value: str) -> None: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") @@ -88,7 +88,9 @@ class EfficientNet(nn.Module): for _ in range(num_repeats): self._blocks.append( MBConvBlock( - **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, + **args, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, ) ) args.in_channels = args.out_channels diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index beb7d57..98e9353 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,7 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -import attr +from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@attr.s(eq=False) +@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - block: nn.Sequential = attr.ib(init=False) + bn_momentum: float = field() + bn_eps: float = field() + block: nn.Sequential = field(init=False) def __attrs_pre_init__(self) -> None: super().__init__() @@ -36,12 +36,12 @@ class BaseModule(nn.Module): return self.block(x) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = attr.ib() - kernel_size: int = attr.ib() - stride: int = attr.ib() + channels: int = field() + kernel_size: int = field() + stride: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +85,13 @@ class Depthwise(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = attr.ib() - channels: int = attr.ib() - se_ratio: float = attr.ib() + in_channels: int = field() + channels: int = field() + se_ratio: float = field() def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -133,32 +133,35 @@ class Pointwise(BaseModule): ) -@attr.s(eq=False) +@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" def __attrs_pre_init__(self) -> None: super().__init__() - in_channels: int = attr.ib() - out_channels: int = attr.ib() - kernel_size: Tuple[int, int] = attr.ib() - stride: Tuple[int, int] = attr.ib(converter=_convert_stride) - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - se_ratio: float = attr.ib() - expand_ratio: int = attr.ib() - pad: Tuple[int, int, int, int] = attr.ib(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False) - _depthwise: nn.Sequential = attr.ib(init=False) - _squeeze_excite: nn.Sequential = attr.ib(init=False) - _pointwise: nn.Sequential = attr.ib(init=False) + in_channels: int = field() + out_channels: int = field() + kernel_size: Tuple[int, int] = field() + stride: Tuple[int, int] = field(converter=_convert_stride) + bn_momentum: float = field() + bn_eps: float = field() + se_ratio: float = field() + expand_ratio: int = field() + pad: Tuple[int, int, int, int] = field(init=False) + _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) + _depthwise: nn.Sequential = field(init=False) + _squeeze_excite: nn.Sequential = field(init=False) + _pointwise: nn.Sequential = field(init=False) @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): - return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 + return ( + (self.kernel_size - 1) // 2 - 1, + (self.kernel_size - 1) // 2, + ) * 2 return ((self.kernel_size - 1) // 2,) * 4 def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 87792a9..aa15b88 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,7 +1,7 @@ """Implementes the attention module for the transformer.""" from typing import Optional, Tuple -import attr +from attrs import define, field from einops import rearrange import torch from torch import einsum @@ -15,22 +15,22 @@ from text_recognizer.networks.transformer.embeddings.rotary import ( ) -@attr.s(eq=False) +@define(eq=False) class Attention(nn.Module): """Standard attention.""" def __attrs_pre_init__(self) -> None: super().__init__() - dim: int = attr.ib() - num_heads: int = attr.ib() - causal: bool = attr.ib(default=False) - dim_head: int = attr.ib(default=64) - dropout_rate: float = attr.ib(default=0.0) - rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None) - scale: float = attr.ib(init=False) - dropout: nn.Dropout = attr.ib(init=False) - fc: nn.Linear = attr.ib(init=False) + dim: int = field() + num_heads: int = field() + causal: bool = field(default=False) + dim_head: int = field(default=64) + dropout_rate: float = field(default=0.0) + rotary_embedding: Optional[RotaryEmbedding] = field(default=None) + scale: float = field(init=False) + dropout: nn.Dropout = field(init=False) + fc: nn.Linear = field(init=False) def __attrs_post_init__(self) -> None: self.scale = self.dim ** -0.5 @@ -120,7 +120,6 @@ def apply_input_mask( input_mask = q_mask * k_mask energy = energy.masked_fill_(~input_mask, mask_value) - del input_mask return energy @@ -133,5 +132,4 @@ def apply_causal_mask( mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") mask = F.pad(mask, (j - i, 0), value=False) energy.masked_fill_(mask, mask_value) - del mask return energy -- cgit v1.2.3-70-g09d2