-rw-r--r--  text_recognizer/models/transformer.py     | 5
-rw-r--r--  text_recognizer/models/vq_transformer.py  | 8
2 files changed, 7 insertions, 6 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index dbdc7f2..0acc30a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,10 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple, Set
+from typing import Tuple, Type, Set
import attr
import torch
from torch import Tensor
+from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -13,6 +14,8 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
+ mapping: Type[AbstractMapping] = attr.ib()
+
max_output_len: int = attr.ib(default=451)
start_token: str = attr.ib(default="<s>")
end_token: str = attr.ib(default="<e>")
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py
index 339ce09..b419e97 100644
--- a/text_recognizer/models/vq_transformer.py
+++ b/text_recognizer/models/vq_transformer.py
@@ -1,10 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
+from typing import Tuple, Type
import attr
import torch
from torch import Tensor
+from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.models.transformer import TransformerLitModel
@@ -12,6 +13,7 @@ from text_recognizer.models.transformer import TransformerLitModel
class VqTransformerLitModel(TransformerLitModel):
"""A PyTorch Lightning model for transformer networks."""
+ mapping: Type[AbstractMapping] = attr.ib()
alpha: float = attr.ib(default=1.0)
def forward(self, data: Tensor) -> Tensor:
@@ -30,8 +32,6 @@ class VqTransformerLitModel(TransformerLitModel):
    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step."""
        data, targets = batch
-
-        # Compute the loss.
        logits, commitment_loss = self.network(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:]) + self.alpha * commitment_loss
        self.log("val/loss", loss, prog_bar=True)
@@ -47,8 +47,6 @@ class VqTransformerLitModel(TransformerLitModel):
    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step."""
        data, targets = batch
-
-        # Compute the text prediction.
        pred = self(data)
        self.test_cer(pred, targets)
        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)