summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
commit34098ccbbbf6379c0bd29a987440b8479c743746 (patch)
treea8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer/models
parentc032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff)
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py31
-rw-r--r--text_recognizer/models/transformer.py26
2 files changed, 19 insertions, 38 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f95df0f..3b83056 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig
-import pytorch_lightning as LightningModule
+from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.networks.base import BaseNetwork
+
@attr.s
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
- network: Type[nn.Module] = attr.ib()
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ network: Type[BaseNetwork] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
@@ -24,23 +29,13 @@ class BaseLitModel(LightningModule):
interval: str = attr.ib()
monitor: str = attr.ib(default="val/loss")
- loss_fn = attr.ib(init=False)
-
- train_acc = attr.ib(init=False)
- val_acc = attr.ib(init=False)
- test_acc = attr.ib(init=False)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- self.loss_fn = self._configure_criterion()
+ loss_fn: Type[nn.Module] = attr.ib(init=False)
- # Accuracy metric
- self.train_acc = torchmetrics.Accuracy()
- self.val_acc = torchmetrics.Accuracy()
- self.test_acc = torchmetrics.Accuracy()
+ train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ @loss_fn.default
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 8c9fe8a..f5cb491 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,13 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple, Type
+from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
import attr
import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping_config: DictConfig = attr.ib(converter=DictConfig)
+ ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+ val_cer: CharacterErrorRate = attr.ib(init=False)
+ test_cer: CharacterErrorRate = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- self.mapping, ignore_tokens = self._configure_mapping()
- self.val_cer = CharacterErrorRate(ignore_tokens)
- self.test_cer = CharacterErrorRate(ignore_tokens)
+ self.val_cer = CharacterErrorRate(self.ignore_tokens)
+ self.test_cer = CharacterErrorRate(self.ignore_tokens)
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- @staticmethod
- def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
- """Configure mapping."""
- # TODO: Fix me!!!
- # Load config with hydra
- mapping, inverse_mapping, _ = emnist_mapping(["\n"])
- start_index = inverse_mapping["<s>"]
- end_index = inverse_mapping["<e>"]
- pad_index = inverse_mapping["<p>"]
- ignore_tokens = [start_index, end_index, pad_index]
- # TODO: add case for sentence pieces
- return mapping, ignore_tokens
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, targets = batch