Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--  text_recognizer/models/base.py | 20 +++++++++-----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8ce5c37..57c5964 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,8 @@ from torch import nn
 from torch import Tensor
 import torchmetrics
 
+from text_recognizer.data.base_mapping import AbstractMapping
+
 
 @attr.s(eq=False)
 class BaseLitModel(LightningModule):
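
The new dependency above is the mapping abstraction from text_recognizer/data/base_mapping.py. Its exact API is not part of this diff; as a rough sketch (method names assumed for illustration), an AbstractMapping converts between tokens and tensor indices:

    from abc import ABC, abstractmethod

    from torch import Tensor


    class AbstractMapping(ABC):
        """Assumed token/index mapping interface (sketch only)."""

        @abstractmethod
        def get_token(self, index: Tensor) -> str:
            """Return the token that a single index maps to."""

        @abstractmethod
        def get_index(self, token: str) -> Tensor:
            """Return the index that a single token maps to."""
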
@@ -20,12 +22,12 @@ class BaseLitModel(LightningModule):
         super().__init__()
 
     network: Type[nn.Module] = attr.ib()
-    criterion_config: DictConfig = attr.ib(converter=DictConfig)
-    optimizer_config: DictConfig = attr.ib(converter=DictConfig)
-    lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+    mapping: Type[AbstractMapping] = attr.ib()
+    loss_fn: Type[nn.Module] = attr.ib()
+    optimizer_config: DictConfig = attr.ib()
+    lr_scheduler_config: DictConfig = attr.ib()
     interval: str = attr.ib()
     monitor: str = attr.ib(default="val/loss")
-    loss_fn: Type[nn.Module] = attr.ib(init=False)
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
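
With criterion_config removed (along with its converter=DictConfig coercion), the collaborators are injected directly and the config fields must already be DictConfig instances. A hypothetical construction under the new signature, assuming network and mapping were built elsewhere and with illustrative config values:

    from omegaconf import DictConfig
    from torch import nn

    model = BaseLitModel(
        network=network,                # any nn.Module built elsewhere
        mapping=mapping,                # an AbstractMapping implementation
        loss_fn=nn.CrossEntropyLoss(),  # criterion is now passed in, not built
        optimizer_config=DictConfig({"_target_": "torch.optim.AdamW", "lr": 3e-4}),
        lr_scheduler_config=DictConfig(
            {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 100}
        ),
        interval="step",
    )
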
@@ -36,12 +38,6 @@
         init=False, default=torchmetrics.Accuracy()
     )
 
-    @loss_fn.default
-    def configure_criterion(self) -> Type[nn.Module]:
-        """Returns a loss functions."""
-        log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
-        return hydra.utils.instantiate(self.criterion_config)
-
     def optimizer_zero_grad(
         self,
         epoch: int,
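
Since the @loss_fn.default factory is deleted, the Hydra call it performed now belongs to whatever assembles the model. A minimal caller-side sketch (config contents assumed):

    import hydra
    from omegaconf import DictConfig

    criterion_config = DictConfig({"_target_": "torch.nn.CrossEntropyLoss"})
    loss_fn = hydra.utils.instantiate(criterion_config)  # builds nn.CrossEntropyLoss()
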
@@ -54,7 +50,9 @@
     def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
         """Configures the optimizer."""
         log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
-        return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())
+        return hydra.utils.instantiate(
+            self.optimizer_config, params=self.network.parameters()
+        )
 
     def _configure_lr_scheduler(
         self, optimizer: Type[torch.optim.Optimizer]
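
The optimizer is now built from self.network.parameters() rather than self.parameters(), so it only optimizes the wrapped network and skips any other submodules registered on the LightningModule (for example a loss with learnable weights). A standalone sketch of what the new call resolves to, with assumed values:

    import hydra
    from omegaconf import DictConfig
    from torch import nn

    network = nn.Linear(10, 2)  # stand-in for the real network
    optimizer_config = DictConfig({"_target_": "torch.optim.AdamW", "lr": 1e-3})
    optimizer = hydra.utils.instantiate(optimizer_config, params=network.parameters())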