summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 0e01bb5..91e088d 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Sequence, Tuple, Type
+from typing import Tuple, Type, Set
import attr
import torch
@@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib()
- start_token: str = attr.ib()
- end_token: str = attr.ib()
- pad_token: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib(default=None)
+ start_token: str = attr.ib(default="<s>")
+ end_token: str = attr.ib(default="<e>")
+ pad_token: str = attr.ib(default="<p>")
start_index: Tensor = attr.ib(init=False)
end_index: Tensor = attr.ib(init=False)
pad_index: Tensor = attr.ib(init=False)
- ignore_indices: Sequence[str] = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
@@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel):
self.start_index = self.mapping.get_index(self.start_token)
self.end_index = self.mapping.get_index(self.end_token)
self.pad_index = self.mapping.get_index(self.pad_token)
- self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)