summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:08:06 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:08:06 +0200
commit283a6fb2c33213dc05d34f1163422f2855506337 (patch)
tree43e1b751e1178c056e719bd42cb3798da612baba /text_recognizer/models
parentc564117e02b0cd13896e044ba61265149780a406 (diff)
Update base model
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py23
1 files changed, 3 insertions, 20 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 39cf78f..34f40a2 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,8 +11,6 @@ from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer.data.base_mapping import AbstractMapping
-
@attr.s(eq=False)
class BaseLitModel(LightningModule):
@@ -23,7 +21,6 @@ class BaseLitModel(LightningModule):
super().__init__()
network: Type[nn.Module] = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
loss_fn: Type[nn.Module] = attr.ib()
optimizer_configs: DictConfig = attr.ib()
lr_scheduler_configs: Optional[DictConfig] = attr.ib()
@@ -104,26 +101,12 @@ class BaseLitModel(LightningModule):
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
- data, targets = batch
- logits = self(data)
- loss = self.loss_fn(logits, targets)
- self.log("train/loss", loss)
- self.train_acc(logits, targets)
- self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
- return loss
+ pass
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Validation step."""
- data, targets = batch
- logits = self(data)
- loss = self.loss_fn(logits, targets)
- self.log("val/loss", loss, prog_bar=True)
- self.val_acc(logits, targets)
- self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+ pass
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
- data, targets = batch
- logits = self(data)
- self.test_acc(logits, targets)
- self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+ pass