Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py  66
1 file changed, 50 insertions(+), 16 deletions(-)
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -6,7 +6,7 @@ import importlib
from pathlib import Path
import re
import shutil
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from loguru import logger
import torch
@@ -15,6 +15,7 @@ from torch import Tensor
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
+from torchvision.transforms import Compose
from text_recognizer.datasets import EmnistMapper
@@ -128,16 +129,42 @@ class Model(ABC):
self._configure_criterion()
self._configure_optimizers()
- # Prints a summary of the network in terminal.
- self.summary()
-
# Set this flag to true to prevent the model from configuring again.
self.is_configured = True
+ def _configure_transforms(self) -> None:
+ # Load transforms.
+ transforms_module = importlib.import_module(
+ "text_recognizer.datasets.transforms"
+ )
+ if (
+ "transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["transform"] is not None
+ ):
+ transform_ = []
+ for t in self.dataset_args["args"]["transform"]:
+ args = t["args"] or {}
+ transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["transform"] = Compose(transform_)
+
+ if (
+ "target_transform" in self.dataset_args["args"]
+ and self.dataset_args["args"]["target_transform"] is not None
+ ):
+ target_transform_ = [
+ torch.tensor,
+ ]
+ for t in self.dataset_args["args"]["target_transform"]:
+ args = t["args"] or {}
+ target_transform_.append(getattr(transforms_module, t["type"])(**args))
+ self.dataset_args["args"]["target_transform"] = Compose(target_transform_)
+
def prepare_data(self) -> None:
"""Prepare data for training."""
# TODO add downloading.
if not self.data_prepared:
+ self._configure_transforms()
+
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
train_dataset.load_or_generate_data()
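
For reference, the dict layout that _configure_transforms consumes looks roughly like the sketch below. This is an assumption pieced together from the lookups above; the dataset type and the transform names ("Transpose", "ToTensor") are hypothetical placeholders, not confirmed members of text_recognizer.datasets.transforms.

    dataset_args = {
        "type": "EmnistDataset",  # hypothetical; any dataset class name
        "args": {
            # Each entry names a class in text_recognizer.datasets.transforms
            # plus its constructor kwargs; "args" may be None.
            "transform": [
                {"type": "Transpose", "args": None},
                {"type": "ToTensor", "args": {}},
            ],
            "target_transform": None,
        },
    }
    # After _configure_transforms runs, "transform" is replaced by a
    # torchvision Compose of the instantiated objects; a non-None
    # "target_transform" list additionally gets torch.tensor prepended.
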
@@ -327,20 +354,20 @@ class Model(ABC):
else:
return self.network(x)
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
- """Compute the loss."""
- return self.criterion(output, targets)
-
def summary(
- self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 4,
+ device: Optional[str] = None,
) -> None:
"""Prints a summary of the network architecture."""
+ device = self.device if device is None else device
if input_shape is not None:
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self.network, input_shape, depth=depth, device=self.device)
+ summary(self.network, input_shape, depth=depth, device=device)
else:
logger.warning("Could not print summary as input shape is not set.")
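
With depth and device now overridable, usage looks roughly like the sketch below; the (1, 1, 28, 28) shape is an assumed batch of one single-channel 28x28 EMNIST image, not a value taken from this diff.

    # Explicit shape and device: forces the torchsummary pass onto CPU
    # even if the model itself sits on a GPU.
    model.summary(input_shape=(1, 1, 28, 28), depth=4, device="cpu")

    # No arguments: falls back to (1,) + self._input_shape on self.device,
    # or logs a warning if no input shape is set.
    model.summary()
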
@@ -356,25 +383,29 @@ class Model(ABC):
state["optimizer_state"] = self._optimizer.state_dict()
if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler.state_dict()
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
if self._swa_network is not None:
state["swa_network"] = self._swa_network.state_dict()
return state
- def load_from_checkpoint(self, checkpoint_path: Path) -> None:
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
"""Load a previously saved checkpoint.
Args:
checkpoint_path (Union[str, Path]): Path to the experiment with the checkpoint.
"""
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
logger.debug("Loading checkpoint...")
if not checkpoint_path.exists():
logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(checkpoint_path))
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
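
The checkpoint hunks above assume self._lr_scheduler is now a dict bundling the scheduler instance with its stepping interval, rather than a bare scheduler object. A sketch of the assumed structure; the keys "lr_scheduler" and "interval" come from the diff, but the scheduler choice and hyperparameters here are illustrative only.

    # Assumed structure of self._lr_scheduler after this change; built
    # elsewhere (e.g. in _configure_optimizers), which this diff does not show.
    self._lr_scheduler = {
        "lr_scheduler": torch.optim.lr_scheduler.OneCycleLR(
            self._optimizer, max_lr=1e-2, total_steps=1000
        ),
        "interval": "step",  # stepped per batch; "epoch" steps once per epoch
    }
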
@@ -383,8 +414,11 @@ class Model(ABC):
if self._lr_scheduler is not None:
# Does not work when loading from a previous checkpoint and trying to train
# beyond the last max epochs with OneCycleLR.
- if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
if self._swa_network is not None:
self._swa_network.load_state_dict(checkpoint["swa_network"])
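
Since load_from_checkpoint now coerces str paths, calls prepare_data() and configure_model() itself, and passes map_location to torch.load, restoring a GPU-trained checkpoint on a CPU-only machine reduces to a single call. A hedged sketch; the path is a placeholder and model is any concrete subclass of this abstract Model.

    # Hypothetical restore flow; the checkpoint path is a placeholder.
    model.load_from_checkpoint("experiments/best.pt")  # str is accepted and coerced
    # Network weights, optimizer state, scheduler state and interval, and any
    # SWA network are restored; map_location remaps tensors onto self.device.
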