path: root/text_recognizer/models
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-05 23:05:25 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-05 23:05:25 +0200
commit     4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree       0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/models
parent     f0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--  text_recognizer/models/__init__.py      2
-rw-r--r--  text_recognizer/models/base.py         11
-rw-r--r--  text_recognizer/models/transformer.py  30
-rw-r--r--  text_recognizer/models/vqvae.py         6
4 files changed, 27 insertions, 22 deletions
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index 5ac2510..1982daf 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -1,3 +1 @@
"""PyTorch Lightning models modules."""
-from .transformer import LitTransformerModel
-from .vqvae import LitVQVAEModel
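With the re-exports removed from __init__.py above, the Lightning models are imported from their submodules directly. A minimal sketch of the post-change import style (class names taken from the diffs below, not guaranteed by this commit alone):

from text_recognizer.models.base import BaseLitModel
from text_recognizer.models.transformer import TransformerLitModel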
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 4e803eb..8dc7a36 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Union, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class LitBaseModel(pl.LightningModule):
+class BaseLitModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule):
val_acc = attr.ib(init=False)
test_acc = attr.ib(init=False)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self):
- self.loss_fn = self.configure_criterion()
+ def __attrs_post_init__(self) -> None:
+ self.loss_fn = self._configure_criterion()
# Accuracy metric
self.train_acc = torchmetrics.Accuracy()
self.val_acc = torchmetrics.Accuracy()
self.test_acc = torchmetrics.Accuracy()
- @staticmethod
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
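The base.py hunk above replaces a hand-written __init__ with attrs hooks: attrs generates __init__ from the attr.ib() fields, __attrs_pre_init__ runs before the fields are assigned (so the LightningModule/nn.Module internals exist first), and __attrs_post_init__ runs afterwards to build derived state. A standalone sketch of that pattern with illustrative names, not code from the repo:

import attr
import pytorch_lightning as pl
from torch import nn


@attr.s
class ExampleLitModel(pl.LightningModule):
    # Fields are declared with attr.ib(); attrs generates __init__ for them.
    network: nn.Module = attr.ib()

    def __attrs_pre_init__(self) -> None:
        # Runs before the generated __init__ assigns fields, so nn.Module's
        # internal state exists before self.network is set and registered.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Runs after all fields are assigned; build derived state here.
        self.loss_fn = nn.CrossEntropyLoss()


model = ExampleLitModel(network=nn.Linear(8, 8))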
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6be0ac5..ea54d83 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,27 +1,35 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import LitBaseModel
-class LitTransformerModel(LitBaseModel):
+@attr.s
+class TransformerLitModel(LitBaseModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
- mapping: Optional[List[str]] = None,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+ network: Type[nn.Module] = attr.ib()
+ criterion_config: DictConfig = attr.ib(converter=DictConfig)
+ optimizer_config: DictConfig = attr.ib(converter=DictConfig)
+ lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ monitor: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ def __attrs_post_init__(self) -> None:
+ super().__init__(
+ network=self.network,
+ optimizer_config=self.optimizer_config,
+ lr_scheduler_config=self.lr_scheduler_config,
+ criterion_config=self.criterion_config,
+ monitor=self.monitor,
+ )
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
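In the attrs field declarations above, converter=DictConfig means a plain dict (for example one resolved by Hydra) is converted to an omegaconf DictConfig when the field is assigned. A small illustrative sketch of that behaviour, not code from the repo:

import attr
from omegaconf import DictConfig


@attr.s
class Example:
    # A plain dict passed here is converted to DictConfig on assignment.
    optimizer_config: DictConfig = attr.ib(converter=DictConfig)


example = Example(optimizer_config={"_target_": "torch.optim.Adam", "lr": 1e-3})
print(type(example.optimizer_config))  # omegaconf.dictconfig.DictConfig
print(example.optimizer_config.lr)     # 0.001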
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 18e8691..7dc950f 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -18,7 +18,7 @@ class LitVQVAEModel(LitBaseModel):
        optimizer: Union[DictConfig, Dict],
        lr_scheduler: Union[DictConfig, Dict],
        criterion: Union[DictConfig, Dict],
-        monitor: str = "val_loss",
+        monitor: str = "val/loss",
        *args: Any,
        **kwargs: Dict,
    ) -> None:
@@ -50,7 +50,7 @@
        reconstructions, vq_loss = self.network(data)
        loss = self.loss_fn(reconstructions, data)
        loss += vq_loss
-        self.log("train_loss", loss)
+        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -59,7 +59,7 @@
        reconstructions, vq_loss = self.network(data)
        loss = self.loss_fn(reconstructions, data)
        loss += vq_loss
-        self.log("val_loss", loss, prog_bar=True)
+        self.log("val/loss", loss, prog_bar=True)
        title = "val_pred_examples"
        self._log_prediction(data, reconstructions, title)
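The vqvae.py hunks rename the logged keys from train_loss/val_loss to train/loss and val/loss; the slash-separated "split/metric" form groups metrics per split in loggers such as TensorBoard or Weights & Biases, and the monitor default is updated to match so checkpointing callbacks watch the same key. A hedged sketch of callbacks configured against the new key (illustrative values, not from the repo):

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Callbacks must monitor the exact key logged with self.log("val/loss", ...).
checkpoint = ModelCheckpoint(monitor="val/loss", mode="min")
early_stopping = EarlyStopping(monitor="val/loss", patience=10)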