diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7272f46..c5120fe 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,7 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -import attr +from attrs import define, field import torch from torch import Tensor @@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = attr.ib(default=451) - start_token: str = attr.ib(default="<s>") - end_token: str = attr.ib(default="<e>") - pad_token: str = attr.ib(default="<p>") + max_output_len: int = field(default=451) + start_token: str = field(default="<s>") + end_token: str = field(default="<e>") + pad_token: str = field(default="<p>") - start_index: int = attr.ib(init=False) - end_index: int = attr.ib(init=False) - pad_index: int = attr.ib(init=False) + start_index: int = field(init=False) + end_index: int = field(init=False) + pad_index: int = field(init=False) - ignore_indices: Set[Tensor] = attr.ib(init=False) - val_cer: CharacterErrorRate = attr.ib(init=False) - test_cer: CharacterErrorRate = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(init=False) + val_cer: CharacterErrorRate = field(init=False) + test_cer: CharacterErrorRate = field(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" |