From 283a6fb2c33213dc05d34f1163422f2855506337 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:08:06 +0200 Subject: Update base model --- text_recognizer/models/base.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 39cf78f..34f40a2 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -11,8 +11,6 @@ from torch import nn from torch import Tensor import torchmetrics -from text_recognizer.data.base_mapping import AbstractMapping - @attr.s(eq=False) class BaseLitModel(LightningModule): @@ -23,7 +21,6 @@ class BaseLitModel(LightningModule): super().__init__() network: Type[nn.Module] = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() loss_fn: Type[nn.Module] = attr.ib() optimizer_configs: DictConfig = attr.ib() lr_scheduler_configs: Optional[DictConfig] = attr.ib() @@ -104,26 +101,12 @@ class BaseLitModel(LightningModule): def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" - data, targets = batch - logits = self(data) - loss = self.loss_fn(logits, targets) - self.log("train/loss", loss) - self.train_acc(logits, targets) - self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) - return loss + pass def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" - data, targets = batch - logits = self(data) - loss = self.loss_fn(logits, targets) - self.log("val/loss", loss, prog_bar=True) - self.val_acc(logits, targets) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + pass def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" - data, targets = batch - logits = self(data) - self.test_acc(logits, targets) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + pass -- cgit v1.2.3-70-g09d2