diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-11 22:09:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-11 22:09:51 +0200 |
commit | 7e0a0a39a54fd7d1a69b9f12bbd98a2b16285c9c (patch) | |
tree | 694ddde83d7af67457f92c1909bcdea39e09eb0c | |
parent | 2c377b6f7e2d4ba8a7c424c748053cc8d560599a (diff) |
Add mapping to base lit model
-rw-r--r-- | text_recognizer/models/base.py | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 34f40a2..8aadc39 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -11,6 +11,7 @@ from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.data.mappings.base_mapping import AbstractMapping @attr.s(eq=False) class BaseLitModel(LightningModule): @@ -24,6 +25,9 @@ class BaseLitModel(LightningModule): loss_fn: Type[nn.Module] = attr.ib() optimizer_configs: DictConfig = attr.ib() lr_scheduler_configs: Optional[DictConfig] = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + # Placeholders train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) |