diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:08:06 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:08:06 +0200 |
commit | 283a6fb2c33213dc05d34f1163422f2855506337 (patch) | |
tree | 43e1b751e1178c056e719bd42cb3798da612baba /text_recognizer | |
parent | c564117e02b0cd13896e044ba61265149780a406 (diff) |
Update base model
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/base.py | 23 |
1 files changed, 3 insertions, 20 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 39cf78f..34f40a2 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -11,8 +11,6 @@ from torch import nn from torch import Tensor import torchmetrics -from text_recognizer.data.base_mapping import AbstractMapping - @attr.s(eq=False) class BaseLitModel(LightningModule): @@ -23,7 +21,6 @@ class BaseLitModel(LightningModule): super().__init__() network: Type[nn.Module] = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() loss_fn: Type[nn.Module] = attr.ib() optimizer_configs: DictConfig = attr.ib() lr_scheduler_configs: Optional[DictConfig] = attr.ib() @@ -104,26 +101,12 @@ class BaseLitModel(LightningModule): def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" - data, targets = batch - logits = self(data) - loss = self.loss_fn(logits, targets) - self.log("train/loss", loss) - self.train_acc(logits, targets) - self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) - return loss + pass def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" - data, targets = batch - logits = self(data) - loss = self.loss_fn(logits, targets) - self.log("val/loss", loss, prog_bar=True) - self.val_acc(logits, targets) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + pass def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" - data, targets = batch - logits = self(data) - self.test_acc(logits, targets) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + pass |