summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py20
-rw-r--r--text_recognizer/models/transformer.py36
2 files changed, 27 insertions, 29 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8ce5c37..57c5964 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,8 @@ from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.data.base_mapping import AbstractMapping
+
@attr.s(eq=False)
class BaseLitModel(LightningModule):
@@ -20,12 +22,12 @@ class BaseLitModel(LightningModule):
super().__init__()
network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ mapping: Type[AbstractMapping] = attr.ib()
+ loss_fn: Type[nn.Module] = attr.ib()
+ optimizer_config: DictConfig = attr.ib()
+ lr_scheduler_config: DictConfig = attr.ib()
interval: str = attr.ib()
monitor: str = attr.ib(default="val/loss")
- loss_fn: Type[nn.Module] = attr.ib(init=False)
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)
@@ -36,12 +38,6 @@ class BaseLitModel(LightningModule):
init=False, default=torchmetrics.Accuracy()
)
- @loss_fn.default
- def configure_criterion(self) -> Type[nn.Module]:
- """Returns a loss functions."""
- log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
- return hydra.utils.instantiate(self.criterion_config)
-
def optimizer_zero_grad(
self,
epoch: int,
@@ -54,7 +50,9 @@ class BaseLitModel(LightningModule):
def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
"""Configures the optimizer."""
log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
- return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())
+ return hydra.utils.instantiate(
+ self.optimizer_config, params=self.network.parameters()
+ )
def _configure_lr_scheduler(
self, optimizer: Type[torch.optim.Optimizer]
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 91e088d..5fb84a7 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -5,7 +5,6 @@ import attr
import torch
from torch import Tensor
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -14,14 +13,14 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib(default=None)
+ max_output_len: int = attr.ib(default=451)
start_token: str = attr.ib(default="<s>")
end_token: str = attr.ib(default="<e>")
pad_token: str = attr.ib(default="<p>")
- start_index: Tensor = attr.ib(init=False)
- end_index: Tensor = attr.ib(init=False)
- pad_index: Tensor = attr.ib(init=False)
+ start_index: int = attr.ib(init=False)
+ end_index: int = attr.ib(init=False)
+ pad_index: int = attr.ib(init=False)
ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
@@ -29,9 +28,9 @@ class TransformerLitModel(BaseLitModel):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = self.mapping.get_index(self.start_token)
- self.end_index = self.mapping.get_index(self.end_token)
- self.pad_index = self.mapping.get_index(self.pad_token)
+ self.start_index = int(self.mapping.get_index(self.start_token))
+ self.end_index = int(self.mapping.get_index(self.end_token))
+ self.pad_index = int(self.mapping.get_index(self.pad_token))
self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)
@@ -93,23 +92,24 @@ class TransformerLitModel(BaseLitModel):
output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
output[:, 0] = self.start_index
- for i in range(1, self.max_output_len):
- context = output[:, :i] # (bsz, i)
- logits = self.network.decode(z, context) # (i, bsz, c)
- tokens = torch.argmax(logits, dim=-1) # (i, bsz)
- output[:, i : i + 1] = tokens[-1:]
+ for Sy in range(1, self.max_output_len):
+ context = output[:, :Sy] # (B, Sy)
+ logits = self.network.decode(z, context) # (B, Sy, C)
+ tokens = torch.argmax(logits, dim=-1) # (B, Sy)
+ output[:, Sy : Sy + 1] = tokens[:, -1:]
# Early stopping of prediction loop if token is end or padding token.
if (
- output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+ (output[:, Sy - 1] == self.end_index)
+ | (output[:, Sy - 1] == self.pad_index)
).all():
break
# Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (
- output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+ for Sy in range(1, self.max_output_len):
+ idx = (output[:, Sy - 1] == self.end_index) | (
+ output[:, Sy - 1] == self.pad_index
)
- output[idx, i] = self.pad_index
+ output[idx, Sy] = self.pad_index
return output