From 34098ccbbbf6379c0bd29a987440b8479c743746 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 29 Jul 2021 23:59:52 +0200 Subject: Configs, refactor with attrs, fix attr bug in iam --- text_recognizer/models/base.py | 31 +++++++++++++------------------ text_recognizer/models/transformer.py | 26 ++++++-------------------- 2 files changed, 19 insertions(+), 38 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f95df0f..3b83056 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type import attr import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig -import pytorch_lightning as LightningModule +from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.networks.base import BaseNetwork + @attr.s class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - network: Type[nn.Module] = attr.ib() + def __attrs_pre_init__(self) -> None: + super().__init__() + + network: Type[BaseNetwork] = attr.ib() criterion_config: DictConfig = attr.ib(converter=DictConfig) optimizer_config: DictConfig = attr.ib(converter=DictConfig) lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) @@ -24,23 +29,13 @@ class BaseLitModel(LightningModule): interval: str = attr.ib() monitor: str = attr.ib(default="val/loss") - loss_fn = attr.ib(init=False) - - train_acc = attr.ib(init=False) - val_acc = attr.ib(init=False) - test_acc = attr.ib(init=False) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def __attrs_post_init__(self) -> None: - self.loss_fn = self._configure_criterion() + loss_fn: Type[nn.Module] = attr.ib(init=False) - # Accuracy metric - self.train_acc = torchmetrics.Accuracy() - self.val_acc = torchmetrics.Accuracy() - self.test_acc = torchmetrics.Accuracy() + train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + @loss_fn.default def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 8c9fe8a..f5cb491 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,13 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Union, Tuple, Type +from typing import Dict, List, Optional, Sequence, Union, Tuple, Type import attr import hydra from omegaconf import DictConfig from torch import nn, Tensor -from text_recognizer.data.emnist import emnist_mapping -from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping_config: DictConfig = attr.ib(converter=DictConfig) + ignore_tokens: Sequence[str] = attr.ib(default=("", "", "

",)) + val_cer: CharacterErrorRate = attr.ib(init=False) + test_cer: CharacterErrorRate = attr.ib(init=False) def __attrs_post_init__(self) -> None: - self.mapping, ignore_tokens = self._configure_mapping() - self.val_cer = CharacterErrorRate(ignore_tokens) - self.test_cer = CharacterErrorRate(ignore_tokens) + self.val_cer = CharacterErrorRate(self.ignore_tokens) + self.test_cer = CharacterErrorRate(self.ignore_tokens) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - @staticmethod - def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: - """Configure mapping.""" - # TODO: Fix me!!! - # Load config with hydra - mapping, inverse_mapping, _ = emnist_mapping(["\n"]) - start_index = inverse_mapping[""] - end_index = inverse_mapping[""] - pad_index = inverse_mapping["

"] - ignore_tokens = [start_index, end_index, pad_index] - # TODO: add case for sentence pieces - return mapping, ignore_tokens - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch -- cgit v1.2.3-70-g09d2