author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-03 18:18:48 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-03 18:18:48 +0200
commit     bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
tree       e55cb3744904f7c2a0348b100c7e92a65e538a16 /text_recognizer/models/transformer.py
parent     75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--  text_recognizer/models/transformer.py  36
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 91e088d..5fb84a7 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -5,7 +5,6 @@
 import attr
 import torch
 from torch import Tensor
 
-from text_recognizer.data.mappings import AbstractMapping
 from text_recognizer.models.metrics import CharacterErrorRate
 from text_recognizer.models.base import BaseLitModel
 
@@ -14,14 +13,14 @@ from text_recognizer.models.base import BaseLitModel
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
-    mapping: Type[AbstractMapping] = attr.ib(default=None)
+    max_output_len: int = attr.ib(default=451)
     start_token: str = attr.ib(default="<s>")
     end_token: str = attr.ib(default="<e>")
     pad_token: str = attr.ib(default="<p>")
 
-    start_index: Tensor = attr.ib(init=False)
-    end_index: Tensor = attr.ib(init=False)
-    pad_index: Tensor = attr.ib(init=False)
+    start_index: int = attr.ib(init=False)
+    end_index: int = attr.ib(init=False)
+    pad_index: int = attr.ib(init=False)
     ignore_indices: Set[Tensor] = attr.ib(init=False)
     val_cer: CharacterErrorRate = attr.ib(init=False)
@@ -29,9 +28,9 @@ class TransformerLitModel(BaseLitModel):
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = self.mapping.get_index(self.start_token)
-        self.end_index = self.mapping.get_index(self.end_token)
-        self.pad_index = self.mapping.get_index(self.pad_token)
+        self.start_index = int(self.mapping.get_index(self.start_token))
+        self.end_index = int(self.mapping.get_index(self.end_token))
+        self.pad_index = int(self.mapping.get_index(self.pad_token))
         self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
         self.val_cer = CharacterErrorRate(self.ignore_indices)
         self.test_cer = CharacterErrorRate(self.ignore_indices)
@@ -93,23 +92,24 @@ class TransformerLitModel(BaseLitModel):
         output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
         output[:, 0] = self.start_index
 
-        for i in range(1, self.max_output_len):
-            context = output[:, :i]  # (bsz, i)
-            logits = self.network.decode(z, context)  # (i, bsz, c)
-            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
-            output[:, i : i + 1] = tokens[-1:]
+        for Sy in range(1, self.max_output_len):
+            context = output[:, :Sy]  # (B, Sy)
+            logits = self.network.decode(z, context)  # (B, Sy, C)
+            tokens = torch.argmax(logits, dim=-1)  # (B, Sy)
+            output[:, Sy : Sy + 1] = tokens[:, -1:]
 
             # Early stopping of prediction loop if token is end or padding token.
             if (
-                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+                (output[:, Sy - 1] == self.end_index)
+                | (output[:, Sy - 1] == self.pad_index)
             ).all():
                 break
 
         # Set all tokens after end token to pad token.
-        for i in range(1, self.max_output_len):
-            idx = (
-                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+        for Sy in range(1, self.max_output_len):
+            idx = (output[:, Sy - 1] == self.end_index) | (
+                output[:, Sy - 1] == self.pad_index
            )
-            output[idx, i] = self.pad_index
+            output[idx, Sy] = self.pad_index
 
         return output
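Two of the fixes deserve a closer look. In Python, bitwise `|` binds more tightly than `==`, so the old early-stopping test `output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index` parsed as a chained comparison around `self.end_index | output[: i - 1]` rather than as "last token is end or pad" (it also sliced rows with `[: i - 1]` where the column index `[:, i - 1]` was meant); on batched tensors the chained comparison raises the usual "Boolean value of Tensor is ambiguous" error. The patch parenthesizes each comparison before OR-ing them. The second fix follows from the batch-first shapes in the updated comments: with logits of shape (B, Sy, C), `torch.argmax(logits, dim=-1)` is (B, Sy), so the newest prediction per sample is `tokens[:, -1:]`, whereas the old `tokens[-1:]` sliced the batch dimension. Below is a minimal, self-contained sketch of the patched greedy decoding loop; the `decode` stub and the hard-coded token indices are assumptions standing in for `self.network.decode` and the real character mapping.

import torch
from torch import Tensor

MAX_OUTPUT_LEN = 10  # the patch defaults to 451; shortened here for the demo
START_INDEX, END_INDEX, PAD_INDEX = 1, 2, 3  # hypothetical mapping indices
NUM_CLASSES = 5


def decode(z: Tensor, context: Tensor) -> Tensor:
    """Stand-in for self.network.decode: returns batch-first (B, Sy, C) logits."""
    bsz, sy = context.shape
    return torch.randn(bsz, sy, NUM_CLASSES)


def greedy_decode(z: Tensor, bsz: int) -> Tensor:
    """Greedy autoregressive decoding, mirroring the patched predict loop."""
    output = torch.ones((bsz, MAX_OUTPUT_LEN), dtype=torch.long)
    output[:, 0] = START_INDEX

    for Sy in range(1, MAX_OUTPUT_LEN):
        context = output[:, :Sy]                 # (B, Sy)
        logits = decode(z, context)              # (B, Sy, C)
        tokens = torch.argmax(logits, dim=-1)    # (B, Sy)
        output[:, Sy : Sy + 1] = tokens[:, -1:]  # keep only the newest token

        # Parentheses are required: `|` binds tighter than `==`.
        if ((output[:, Sy - 1] == END_INDEX) | (output[:, Sy - 1] == PAD_INDEX)).all():
            break

    # Force everything after an end/pad token to the pad token.
    for Sy in range(1, MAX_OUTPUT_LEN):
        idx = (output[:, Sy - 1] == END_INDEX) | (output[:, Sy - 1] == PAD_INDEX)
        output[idx, Sy] = PAD_INDEX

    return output


print(greedy_decode(torch.zeros(2, 4), bsz=2))

Note that the early stop only fires once every sequence in the batch has emitted an end or pad token; the second loop then pads out any tokens a finished sequence produced while slower sequences were still decoding.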