summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 34f40a2..8aadc39 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,6 +11,7 @@ from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.data.mappings.base_mapping import AbstractMapping
@attr.s(eq=False)
class BaseLitModel(LightningModule):
@@ -24,6 +25,9 @@ class BaseLitModel(LightningModule):
loss_fn: Type[nn.Module] = attr.ib()
optimizer_configs: DictConfig = attr.ib()
lr_scheduler_configs: Optional[DictConfig] = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ # Placeholders
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)