summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:09:51 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-11 22:09:51 +0200
commit7e0a0a39a54fd7d1a69b9f12bbd98a2b16285c9c (patch)
tree694ddde83d7af67457f92c1909bcdea39e09eb0c /text_recognizer
parent2c377b6f7e2d4ba8a7c424c748053cc8d560599a (diff)
Add mapping to base lit model
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/models/base.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 34f40a2..8aadc39 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,7 @@ from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.data.mappings.base_mapping import AbstractMapping
@attr.s(eq=False)
class BaseLitModel(LightningModule):
@@ -24,6 +25,9 @@ class BaseLitModel(LightningModule):
loss_fn: Type[nn.Module] = attr.ib()
optimizer_configs: DictConfig = attr.ib()
lr_scheduler_configs: Optional[DictConfig] = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ # Placeholders
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)