diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bb4e695..f8f4b40 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -9,7 +9,7 @@ from pytorch_lightning import LightningModule from torch import nn, Tensor from torchmetrics import Accuracy -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer class LitBase(LightningModule): @@ -21,8 +21,7 @@ class LitBase(LightningModule): loss_fn: Type[nn.Module], optimizer_config: DictConfig, lr_scheduler_config: Optional[DictConfig], - mapping: EmnistMapping, - ignore_index: Optional[int] = None, + tokenizer: Tokenizer, ) -> None: super().__init__() @@ -30,8 +29,8 @@ class LitBase(LightningModule): self.loss_fn = loss_fn self.optimizer_config = optimizer_config self.lr_scheduler_config = lr_scheduler_config - self.mapping = mapping - + self.tokenizer = tokenizer + ignore_index = int(self.tokenizer.get_value("<p>")) # Placeholders self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) |