summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 4e803eb..8dc7a36 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Union, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class LitBaseModel(pl.LightningModule):
+class BaseLitModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule):
val_acc = attr.ib(init=False)
test_acc = attr.ib(init=False)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self):
- self.loss_fn = self.configure_criterion()
+ def __attrs_post_init__(self) -> None:
+ self.loss_fn = self._configure_criterion()
# Accuracy metric
self.train_acc = torchmetrics.Accuracy()
self.val_acc = torchmetrics.Accuracy()
self.test_acc = torchmetrics.Accuracy()
- @staticmethod
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")