summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
commit75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree6521cc4134459e42591b2375f70acd348741474e /text_recognizer/models/transformer.py
parente5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 0e01bb5..91e088d 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Sequence, Tuple, Type
+from typing import Tuple, Type, Set
import attr
import torch
@@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib()
- start_token: str = attr.ib()
- end_token: str = attr.ib()
- pad_token: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib(default=None)
+ start_token: str = attr.ib(default="<s>")
+ end_token: str = attr.ib(default="<e>")
+ pad_token: str = attr.ib(default="<p>")
start_index: Tensor = attr.ib(init=False)
end_index: Tensor = attr.ib(init=False)
pad_index: Tensor = attr.ib(init=False)
- ignore_indices: Sequence[str] = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
@@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel):
self.start_index = self.mapping.get_index(self.start_token)
self.end_index = self.mapping.get_index(self.end_token)
self.pad_index = self.mapping.get_index(self.pad_token)
- self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)