summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index bb4e695..f8f4b40 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -9,7 +9,7 @@ from pytorch_lightning import LightningModule
from torch import nn, Tensor
from torchmetrics import Accuracy
-from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.data.tokenizer import Tokenizer
class LitBase(LightningModule):
@@ -21,8 +21,7 @@ class LitBase(LightningModule):
loss_fn: Type[nn.Module],
optimizer_config: DictConfig,
lr_scheduler_config: Optional[DictConfig],
- mapping: EmnistMapping,
- ignore_index: Optional[int] = None,
+ tokenizer: Tokenizer,
) -> None:
super().__init__()
@@ -30,8 +29,8 @@ class LitBase(LightningModule):
self.loss_fn = loss_fn
self.optimizer_config = optimizer_config
self.lr_scheduler_config = lr_scheduler_config
- self.mapping = mapping
-
+ self.tokenizer = tokenizer
+ ignore_index = int(self.tokenizer.get_value("<p>"))
# Placeholders
self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)