diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f917635..bb4e695 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -6,7 +6,7 @@ import torch from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import LightningModule -from torch import Tensor, nn +from torch import nn, Tensor from torchmetrics import Accuracy from text_recognizer.data.mappings import EmnistMapping @@ -22,6 +22,7 @@ class LitBase(LightningModule): optimizer_config: DictConfig, lr_scheduler_config: Optional[DictConfig], mapping: EmnistMapping, + ignore_index: Optional[int] = None, ) -> None: super().__init__() @@ -32,9 +33,9 @@ class LitBase(LightningModule): self.mapping = mapping # Placeholders - self.train_acc = Accuracy(mdmc_reduce="samplewise") - self.val_acc = Accuracy(mdmc_reduce="samplewise") - self.test_acc = Accuracy(mdmc_reduce="samplewise") + self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) def optimizer_zero_grad( self, |