diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 29 |
1 files changed, 12 insertions, 17 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index c5120fe..9537dd9 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,6 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -from attrs import define, field import torch from torch import Tensor @@ -9,25 +8,21 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = field(default=451) - start_token: str = field(default="<s>") - end_token: str = field(default="<e>") - pad_token: str = field(default="<p>") - - start_index: int = field(init=False) - end_index: int = field(init=False) - pad_index: int = field(init=False) - - ignore_indices: Set[Tensor] = field(init=False) - val_cer: CharacterErrorRate = field(init=False) - test_cer: CharacterErrorRate = field(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__( + self, + max_output_len: int = 451, + start_token: str = "<s>", + end_token: str = "<e>", + pad_token: str = "<p>", + ) -> None: + super().__init__() + self.max_output_len = max_output_len + self.start_token = start_token + self.end_token = end_token + self.pad_token = pad_token self.start_index = int(self.mapping.get_index(self.start_token)) self.end_index = int(self.mapping.get_index(self.end_token)) self.pad_index = int(self.mapping.get_index(self.pad_token)) |