diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 34f40a2..8aadc39 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -11,6 +11,7 @@ from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.data.mappings.base_mapping import AbstractMapping @attr.s(eq=False) class BaseLitModel(LightningModule): @@ -24,6 +25,9 @@ class BaseLitModel(LightningModule): loss_fn: Type[nn.Module] = attr.ib() optimizer_configs: DictConfig = attr.ib() lr_scheduler_configs: Optional[DictConfig] = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + # Placeholders train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) |