summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
commit34098ccbbbf6379c0bd29a987440b8479c743746 (patch)
treea8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer/models/transformer.py
parentc032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff)
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py26
1 files changed, 6 insertions, 20 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 8c9fe8a..f5cb491 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,13 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple, Type
+from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
import attr
import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping_config: DictConfig = attr.ib(converter=DictConfig)
+ ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+ val_cer: CharacterErrorRate = attr.ib(init=False)
+ test_cer: CharacterErrorRate = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- self.mapping, ignore_tokens = self._configure_mapping()
- self.val_cer = CharacterErrorRate(ignore_tokens)
- self.test_cer = CharacterErrorRate(ignore_tokens)
+ self.val_cer = CharacterErrorRate(self.ignore_tokens)
+ self.test_cer = CharacterErrorRate(self.ignore_tokens)
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- @staticmethod
- def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
- """Configure mapping."""
- # TODO: Fix me!!!
- # Load config with hydra
- mapping, inverse_mapping, _ = emnist_mapping(["\n"])
- start_index = inverse_mapping["<s>"]
- end_index = inverse_mapping["<e>"]
- pad_index = inverse_mapping["<p>"]
- ignore_tokens = [start_index, end_index, pad_index]
- # TODO: add case for sentence pieces
- return mapping, ignore_tokens
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, targets = batch