diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 0e01bb5..91e088d 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
 """PyTorch Lightning model for base Transformers."""
-from typing import Sequence, Tuple, Type
+from typing import Tuple, Type, Set
 
 import attr
 import torch
@@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate
 from text_recognizer.models.base import BaseLitModel
 
 
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
-    mapping: Type[AbstractMapping] = attr.ib()
-    start_token: str = attr.ib()
-    end_token: str = attr.ib()
-    pad_token: str = attr.ib()
+    mapping: Type[AbstractMapping] = attr.ib(default=None)
+    start_token: str = attr.ib(default="<s>")
+    end_token: str = attr.ib(default="<e>")
+    pad_token: str = attr.ib(default="<p>")
 
     start_index: Tensor = attr.ib(init=False)
     end_index: Tensor = attr.ib(init=False)
     pad_index: Tensor = attr.ib(init=False)
 
-    ignore_indices: Sequence[str] = attr.ib(init=False)
+    ignore_indices: Set[Tensor] = attr.ib(init=False)
     val_cer: CharacterErrorRate = attr.ib(init=False)
     test_cer: CharacterErrorRate = attr.ib(init=False)
 
@@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel):
         self.start_index = self.mapping.get_index(self.start_token)
         self.end_index = self.mapping.get_index(self.end_token)
         self.pad_index = self.mapping.get_index(self.pad_token)
-        self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+        self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
         self.val_cer = CharacterErrorRate(self.ignore_indices)
         self.test_cer = CharacterErrorRate(self.ignore_indices)
 