diff options
-rw-r--r-- | text_recognizer/models/transformer.py | 5 | ||||
-rw-r--r-- | text_recognizer/models/vq_transformer.py | 8 |
2 files changed, 7 insertions, 6 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index dbdc7f2..0acc30a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,10 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Tuple, Set +from typing import Tuple, Type, Set import attr import torch from torch import Tensor +from text_recognizer.data.base_mapping import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -13,6 +14,8 @@ from text_recognizer.models.base import BaseLitModel class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" + mapping: Type[AbstractMapping] = attr.ib() + max_output_len: int = attr.ib(default=451) start_token: str = attr.ib(default="<s>") end_token: str = attr.ib(default="<e>") diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py index 339ce09..b419e97 100644 --- a/text_recognizer/models/vq_transformer.py +++ b/text_recognizer/models/vq_transformer.py @@ -1,10 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Tuple +from typing import Tuple, Type import attr import torch from torch import Tensor +from text_recognizer.data.base_mapping import AbstractMapping from text_recognizer.models.transformer import TransformerLitModel @@ -12,6 +13,7 @@ from text_recognizer.models.transformer import TransformerLitModel class VqTransformerLitModel(TransformerLitModel): """A PyTorch Lightning model for transformer networks.""" + mapping: Type[AbstractMapping] = attr.ib() alpha: float = attr.ib(default=1.0) def forward(self, data: Tensor) -> Tensor: @@ -30,8 +32,6 @@ class VqTransformerLitModel(TransformerLitModel): def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, targets = batch - - # Compute the loss. logits, commitment_loss = self.network(data, targets[:, :-1]) loss = self.loss_fn(logits, targets[:, 1:]) + self.alpha * commitment_loss self.log("val/loss", loss, prog_bar=True) @@ -47,8 +47,6 @@ class VqTransformerLitModel(TransformerLitModel): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, targets = batch - - # Compute the text prediction. pred = self(data) self.test_cer(pred, targets) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) |