summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py31
1 files changed, 13 insertions, 18 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f95df0f..3b83056 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig
-import pytorch_lightning as LightningModule
+from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.networks.base import BaseNetwork
+
@attr.s
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
- network: Type[nn.Module] = attr.ib()
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ network: Type[BaseNetwork] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
@@ -24,23 +29,13 @@ class BaseLitModel(LightningModule):
interval: str = attr.ib()
monitor: str = attr.ib(default="val/loss")
- loss_fn = attr.ib(init=False)
-
- train_acc = attr.ib(init=False)
- val_acc = attr.ib(init=False)
- test_acc = attr.ib(init=False)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- self.loss_fn = self._configure_criterion()
+ loss_fn: Type[nn.Module] = attr.ib(init=False)
- # Accuracy metric
- self.train_acc = torchmetrics.Accuracy()
- self.val_acc = torchmetrics.Accuracy()
- self.test_acc = torchmetrics.Accuracy()
+ train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ @loss_fn.default
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")