diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--  text_recognizer/models/base.py        | 36
-rw-r--r--  text_recognizer/models/metrics.py     | 13
-rw-r--r--  text_recognizer/models/transformer.py | 29
3 files changed, 36 insertions, 42 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index bf3bc08..63fe5a7 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,7 +1,6 @@
 """Base PyTorch Lightning model."""
 from typing import Any, Dict, List, Optional, Tuple, Type
 
-from attrs import define, field
 import hydra
 from loguru import logger as log
 from omegaconf import DictConfig
@@ -9,31 +8,34 @@ from pytorch_lightning import LightningModule
 import torch
 from torch import nn
 from torch import Tensor
-import torchmetrics
+from torchmetrics import Accuracy
 
 from text_recognizer.data.mappings.base import AbstractMapping
 
 
-@define(eq=False)
 class BaseLitModel(LightningModule):
     """Abstract PyTorch Lightning class."""
 
-    def __attrs_pre_init__(self) -> None:
-        """Pre init constructor."""
+    def __init__(
+        self,
+        network: Type[nn.Module],
+        loss_fn: Type[nn.Module],
+        optimizer_configs: DictConfig,
+        lr_scheduler_configs: Optional[DictConfig],
+        mapping: Type[AbstractMapping],
+    ) -> None:
         super().__init__()
 
-    network: Type[nn.Module] = field()
-    loss_fn: Type[nn.Module] = field()
-    optimizer_configs: DictConfig = field()
-    lr_scheduler_configs: Optional[DictConfig] = field()
-    mapping: Type[AbstractMapping] = field()
-
-    # Placeholders
-    train_acc: torchmetrics.Accuracy = field(
-        init=False, default=torchmetrics.Accuracy()
-    )
-    val_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
-    test_acc: torchmetrics.Accuracy = field(init=False, default=torchmetrics.Accuracy())
+        self.network = network
+        self.loss_fn = loss_fn
+        self.optimizer_configs = optimizer_configs
+        self.lr_scheduler_configs = lr_scheduler_configs
+        self.mapping = mapping
+
+        # Placeholders
+        self.train_acc = Accuracy()
+        self.val_acc = Accuracy()
+        self.test_acc = Accuracy()
 
     def optimizer_zero_grad(
         self,
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index e59a830..3cb16b5 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,25 +1,22 @@
 """Character Error Rate (CER)."""
-from typing import Set
+from typing import Sequence
 
-from attrs import define, field
 import editdistance
 import torch
 from torch import Tensor
 from torchmetrics import Metric
 
 
-@define(eq=False)
 class CharacterErrorRate(Metric):
     """Character error rate metric, computed using Levenshtein distance."""
 
-    ignore_indices: Set[Tensor] = field(converter=set)
-    error: Tensor = field(init=False)
-    total: Tensor = field(init=False)
-
-    def __attrs_post_init__(self) -> None:
+    def __init__(self, ignore_indices: Sequence[Tensor]) -> None:
         super().__init__()
+        self.ignore_indices = set(ignore_indices)
         self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.error: Tensor
+        self.total: Tensor
 
     def update(self, preds: Tensor, targets: Tensor) -> None:
         """Update CER."""
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index c5120fe..9537dd9 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,7 +1,6 @@
 """PyTorch Lightning model for base Transformers."""
 from typing import Set, Tuple
 
-from attrs import define, field
 import torch
 from torch import Tensor
 
@@ -9,25 +8,21 @@ from text_recognizer.models.base import BaseLitModel
 from text_recognizer.models.metrics import CharacterErrorRate
 
 
-@define(auto_attribs=True, eq=False)
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
-    max_output_len: int = field(default=451)
-    start_token: str = field(default="<s>")
-    end_token: str = field(default="<e>")
-    pad_token: str = field(default="<p>")
-
-    start_index: int = field(init=False)
-    end_index: int = field(init=False)
-    pad_index: int = field(init=False)
-
-    ignore_indices: Set[Tensor] = field(init=False)
-    val_cer: CharacterErrorRate = field(init=False)
-    test_cer: CharacterErrorRate = field(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+    def __init__(
+        self,
+        max_output_len: int = 451,
+        start_token: str = "<s>",
+        end_token: str = "<e>",
+        pad_token: str = "<p>",
+    ) -> None:
+        super().__init__()
+        self.max_output_len = max_output_len
+        self.start_token = start_token
+        self.end_token = end_token
+        self.pad_token = pad_token
         self.start_index = int(self.mapping.get_index(self.start_token))
         self.end_index = int(self.mapping.get_index(self.end_token))
         self.pad_index = int(self.mapping.get_index(self.pad_token))