summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
commiteb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models
parent4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py5
-rw-r--r--text_recognizer/models/metrics.py15
-rw-r--r--text_recognizer/models/transformer.py26
-rw-r--r--text_recognizer/models/vqvae.py34
4 files changed, 24 insertions, 56 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8dc7a36..f95df0f 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -5,7 +5,7 @@ import attr
import hydra
import loguru.logger as log
from omegaconf import DictConfig
-import pytorch_lightning as pl
+import pytorch_lightning as LightningModule
import torch
from torch import nn
from torch import Tensor
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class BaseLitModel(pl.LightningModule):
+class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule):
"""Configures optimizer and lr scheduler."""
optimizer = self._configure_optimizer()
scheduler = self._configure_lr_scheduler(optimizer)
-
return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor:
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 58d0537..4117ae2 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,18 +1,23 @@
"""Character Error Rate (CER)."""
-from typing import Sequence
+from typing import Set, Sequence
+import attr
import editdistance
import torch
from torch import Tensor
-import torchmetrics
+from torchmetrics import Metric
-class CharacterErrorRate(torchmetrics.Metric):
+@attr.s
+class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+ ignore_tokens: Set = attr.ib(converter=set)
+ error: Tensor = attr.ib(init=False)
+ total: Tensor = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
super().__init__()
- self.ignore_tokens = set(ignore_tokens)
self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index ea54d83..8c9fe8a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -2,35 +2,24 @@
from typing import Dict, List, Optional, Union, Tuple, Type
import attr
+import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-@attr.s
-class TransformerLitModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
- monitor: str = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
+ mapping_config: DictConfig = attr.ib(converter=DictConfig)
def __attrs_post_init__(self) -> None:
- super().__init__(
- network=self.network,
- optimizer_config=self.optimizer_config,
- lr_scheduler_config=self.lr_scheduler_config,
- criterion_config=self.criterion_config,
- monitor=self.monitor,
- )
- self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.mapping, ignore_tokens = self._configure_mapping()
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
@@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel):
return self.network.predict(data)
@staticmethod
- def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
+ def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
"""Configure mapping."""
# TODO: Fix me!!!
+ # Load config with hydra
mapping, inverse_mapping, _ = emnist_mapping(["\n"])
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 7dc950f..0172163 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -1,49 +1,23 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Any, Dict, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn
from torch import Tensor
import wandb
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-class LitVQVAEModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val/loss",
- *args: Any,
- **kwargs: Dict,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- def _log_prediction(
- self, data: Tensor, reconstructions: Tensor, title: str
- ) -> None:
- """Logs prediction on image with wandb."""
- try:
- self.logger.experiment.log(
- {
- title: [
- wandb.Image(data[0]),
- wandb.Image(reconstructions[0]),
- ]
- }
- )
- except AttributeError:
- pass
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, _ = batch