Diffstat (limited to 'text_recognizer/models')
-rw-r--r--  text_recognizer/models/base.py        | 20
-rw-r--r--  text_recognizer/models/transformer.py | 36
2 files changed, 27 insertions, 29 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8ce5c37..57c5964 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,8 @@ from torch import nn
 from torch import Tensor
 import torchmetrics
 
+from text_recognizer.data.base_mapping import AbstractMapping
+
 
 @attr.s(eq=False)
 class BaseLitModel(LightningModule):
@@ -20,12 +22,12 @@ class BaseLitModel(LightningModule):
         super().__init__()
 
     network: Type[nn.Module] = attr.ib()
-    criterion_config: DictConfig = attr.ib(converter=DictConfig)
-    optimizer_config: DictConfig = attr.ib(converter=DictConfig)
-    lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+    mapping: Type[AbstractMapping] = attr.ib()
+    loss_fn: Type[nn.Module] = attr.ib()
+    optimizer_config: DictConfig = attr.ib()
+    lr_scheduler_config: DictConfig = attr.ib()
     interval: str = attr.ib()
     monitor: str = attr.ib(default="val/loss")
-    loss_fn: Type[nn.Module] = attr.ib(init=False)
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
@@ -36,12 +38,6 @@ class BaseLitModel(LightningModule):
         init=False, default=torchmetrics.Accuracy()
     )
 
-    @loss_fn.default
-    def configure_criterion(self) -> Type[nn.Module]:
-        """Returns a loss functions."""
-        log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
-        return hydra.utils.instantiate(self.criterion_config)
-
     def optimizer_zero_grad(
         self,
         epoch: int,
@@ -54,7 +50,9 @@ class BaseLitModel(LightningModule):
     def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
         """Configures the optimizer."""
         log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
-        return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())
+        return hydra.utils.instantiate(
+            self.optimizer_config, params=self.network.parameters()
+        )
 
     def _configure_lr_scheduler(
         self, optimizer: Type[torch.optim.Optimizer]
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 91e088d..5fb84a7 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -5,7 +5,6 @@ import attr
 import torch
 from torch import Tensor
 
-from text_recognizer.data.mappings import AbstractMapping
 from text_recognizer.models.metrics import CharacterErrorRate
 from text_recognizer.models.base import BaseLitModel
 
@@ -14,14 +13,14 @@ from text_recognizer.models.base import BaseLitModel
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
-    mapping: Type[AbstractMapping] = attr.ib(default=None)
+    max_output_len: int = attr.ib(default=451)
     start_token: str = attr.ib(default="<s>")
     end_token: str = attr.ib(default="<e>")
     pad_token: str = attr.ib(default="<p>")
 
-    start_index: Tensor = attr.ib(init=False)
-    end_index: Tensor = attr.ib(init=False)
-    pad_index: Tensor = attr.ib(init=False)
+    start_index: int = attr.ib(init=False)
+    end_index: int = attr.ib(init=False)
+    pad_index: int = attr.ib(init=False)
 
     ignore_indices: Set[Tensor] = attr.ib(init=False)
     val_cer: CharacterErrorRate = attr.ib(init=False)
@@ -29,9 +28,9 @@ class TransformerLitModel(BaseLitModel):
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = self.mapping.get_index(self.start_token)
-        self.end_index = self.mapping.get_index(self.end_token)
-        self.pad_index = self.mapping.get_index(self.pad_token)
+        self.start_index = int(self.mapping.get_index(self.start_token))
+        self.end_index = int(self.mapping.get_index(self.end_token))
+        self.pad_index = int(self.mapping.get_index(self.pad_token))
         self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
         self.val_cer = CharacterErrorRate(self.ignore_indices)
         self.test_cer = CharacterErrorRate(self.ignore_indices)
@@ -93,23 +92,24 @@ class TransformerLitModel(BaseLitModel):
         output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
         output[:, 0] = self.start_index
 
-        for i in range(1, self.max_output_len):
-            context = output[:, :i]  # (bsz, i)
-            logits = self.network.decode(z, context)  # (i, bsz, c)
-            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
-            output[:, i : i + 1] = tokens[-1:]
+        for Sy in range(1, self.max_output_len):
+            context = output[:, :Sy]  # (B, Sy)
+            logits = self.network.decode(z, context)  # (B, Sy, C)
+            tokens = torch.argmax(logits, dim=-1)  # (B, Sy)
+            output[:, Sy : Sy + 1] = tokens[:, -1:]
 
             # Early stopping of prediction loop if token is end or padding token.
             if (
-                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+                (output[:, Sy - 1] == self.end_index)
+                | (output[:, Sy - 1] == self.pad_index)
             ).all():
                 break
 
         # Set all tokens after end token to pad token.
-        for i in range(1, self.max_output_len):
-            idx = (
-                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+        for Sy in range(1, self.max_output_len):
+            idx = (output[:, Sy - 1] == self.end_index) | (
+                output[:, Sy - 1] == self.pad_index
             )
-            output[idx, i] = self.pad_index
+            output[idx, Sy] = self.pad_index
 
         return output
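In base.py, the criterion is now injected as a ready-built loss_fn instead of being instantiated from a criterion_config, while the optimizer is still constructed from config via Hydra, now over self.network.parameters() rather than the whole LightningModule's parameters. A minimal sketch of that instantiation pattern, with an illustrative config (the repository's actual optimizer configs may differ):

    import hydra
    from omegaconf import DictConfig
    from torch import nn

    # Illustrative config; `_target_` names the class for Hydra to construct.
    optimizer_config = DictConfig({"_target_": "torch.optim.AdamW", "lr": 3e-4})

    network = nn.Linear(8, 8)

    # hydra.utils.instantiate imports the `_target_` class and calls it with
    # the remaining config keys plus any extra kwargs, here the parameters.
    optimizer = hydra.utils.instantiate(optimizer_config, params=network.parameters())
    print(type(optimizer).__name__)  # AdamW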
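The decoding changes in transformer.py are largely a fix for an operator-precedence bug: in Python, | binds tighter than ==, so the old early-stopping test did not compare what it appears to compare. A minimal sketch of the pitfall with illustrative values:

    import torch

    end_index, pad_index = 2, 3
    output = torch.tensor([2, 3, 5])

    # The old, unparenthesized form parses as a chained comparison against
    # (end_index | output), i.e. `output == (end_index | output) == pad_index`,
    # which raises "Boolean value of Tensor ... is ambiguous" on batches.
    # mask = output == end_index | output == pad_index

    # The fixed form ORs the two boolean masks elementwise, as intended.
    mask = (output == end_index) | (output == pad_index)
    print(mask)  # tensor([ True,  True, False])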