path: root/src/text_recognizer/models/base.py
author    aktersnurra <gustaf.rydholm@gmail.com>  2020-09-08 23:14:23 +0200
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-09-08 23:14:23 +0200
commit    e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree      70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/models/base.py
parent    fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py  331
1 file changed, 218 insertions(+), 113 deletions(-)
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 3a84a11..153e19a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from glob import glob
+import importlib
from pathlib import Path
import re
import shutil
@@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type
from loguru import logger
import torch
from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
-from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
+from text_recognizer.datasets import EmnistMapper
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -23,8 +27,9 @@ class Model(ABC):
def __init__(
self,
network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
network_args: Optional[Dict] = None,
- data_loader_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
criterion_args: Optional[Dict] = None,
@@ -32,14 +37,16 @@ class Model(ABC):
optimizer_args: Optional[Dict] = None,
lr_scheduler: Optional[Callable] = None,
lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
"""Base class, to be inherited by model for specific type of data.
Args:
network_fn (Type[nn.Module]): The PyTorch network.
+ dataset (Type[Dataset]): A dataset class.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- data_loader_args (Optional[Dict]): Arguments for the DataLoader.
+ dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
@@ -49,107 +56,181 @@ class Model(ABC):
lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
None.
+ swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+ None.
device (Optional[str]): Name of the device to train on. Defaults to None.
"""
+ # Has to be set in subclass.
+ self._mapper = None
- # Configure data loaders and dataset info.
- dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
- data_loader_args
- )
- self._input_shape = self._mapper.input_shape
+ # Placeholder.
+ self._input_shape = None
+
+ self.dataset = dataset
+ self.dataset_args = dataset_args
+
+ # Placeholders for datasets.
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
- self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+ # Stochastic Weight Averaging placeholders.
+ self.swa_args = swa_args
+ self._swa_start = None
+ self._swa_scheduler = None
+ self._swa_network = None
- if metrics is not None:
- self._metrics = metrics
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flags for a configured model and prepared data.
+ self.is_configured = False
+ self.data_prepared = False
+
+ # Flag for stopping training.
+ self.stop_training = False
+
+ self._name = (
+ f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
+ )
+
+ self._metrics = metrics
# Set the device.
- if device is None:
- self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- else:
- self._device = device
+ self._device = (
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device is None
+ else device
+ )
# Configure network.
- self._network, self._network_args = self._configure_network(
- network_fn, network_args
- )
+ self._network = None
+ self._network_args = network_args
+ self._configure_network(network_fn)
- # To device.
- self._network.to(self._device)
+ # Place network on device (GPU).
+ self.to_device()
+
+ # Criterion, optimizer, and scheduler placeholders; configured later in configure_model().
+ self._criterion = criterion
+ self.criterion_args = criterion_args
+
+ self._optimizer = optimizer
+ self.optimizer_args = optimizer_args
+
+ self._lr_scheduler = lr_scheduler
+ self.lr_scheduler_args = lr_scheduler_args
+
+ def configure_model(self) -> None:
+ """Configures criterion and optimizers."""
+ if not self.is_configured:
+ self._configure_criterion()
+ self._configure_optimizers()
+
+ # Prints a summary of the network in the terminal.
+ self.summary()
+
+ # Set this flag to true to prevent the model from configuring again.
+ self.is_configured = True
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ # TODO add downloading.
+ if not self.data_prepared:
+ # Load train dataset.
+ train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+
+ # Set input shape.
+ self._input_shape = train_dataset.input_shape
+
+ # Split train dataset into a training and validation partition.
+ dataset_len = len(train_dataset)
+ train_len = int(
+ self.dataset_args["train_args"]["train_fraction"] * dataset_len
+ )
+ val_len = dataset_len - train_len
+ self.train_dataset, self.val_dataset = random_split(
+ train_dataset, lengths=[train_len, val_len]
+ )
+
+ # Load test dataset.
+ self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+
+ # Set the flag to true to disable the ability to load the data again.
+ self.data_prepared = True
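For reference, `prepare_data` and the dataloader methods below read `dataset_args` as a nested dict: an "args" key forwarded to the dataset constructor, and a "train_args" key consumed for splitting and batching. A minimal sketch of that contract; the keys come from this diff, while the concrete values are illustrative placeholders:

    dataset_args = {
        # Keyword arguments forwarded to the Dataset constructor.
        "args": {},
        # Arguments consumed by prepare_data() and the *_dataloader() methods.
        "train_args": {
            "train_fraction": 0.85,  # Fraction of the training data kept for training.
            "batch_size": 128,
            "num_workers": 4,
        },
    }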
- # Configure training objects.
- self._criterion = self._configure_criterion(criterion, criterion_args)
- self._optimizer, self._lr_scheduler = self._configure_optimizers(
- optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+ def train_dataloader(self) -> DataLoader:
+ """Returns data loader for training set."""
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
)
- # Experiment directory.
- self.model_dir = None
+ def val_dataloader(self) -> DataLoader:
+ """Returns data loader for validation set."""
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
- # Flag for stopping training.
- self.stop_training = False
+ def test_dataloader(self) -> DataLoader:
+ """Returns data loader for test set."""
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=False,
+ pin_memory=True,
+ )
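Taken together, the new lifecycle is: construct the model, call `prepare_data`, call `configure_model`, then fetch the loaders. A hedged usage sketch; `CharacterModel`, `MLP`, and `EmnistDataset` stand in for a concrete subclass, network, and dataset, and the hyperparameters are placeholders:

    from torch import nn, optim

    model = CharacterModel(
        network_fn=MLP,
        dataset=EmnistDataset,
        network_args={"num_classes": 62},  # Hypothetical network arguments.
        dataset_args=dataset_args,         # As sketched above.
        criterion=nn.CrossEntropyLoss,
        optimizer=optim.Adam,
        optimizer_args={"lr": 3e-4},
    )
    model.prepare_data()     # Loads and splits the datasets, sets input_shape.
    model.configure_model()  # Instantiates criterion/optimizer and prints a summary.
    train_loader = model.train_dataloader()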
- def _configure_data_loader(
- self, data_loader_args: Optional[Dict]
- ) -> Tuple[str, Dict, EmnistMapper]:
- """Loads data loader, dataset name, and dataset mapper."""
- if data_loader_args is not None:
- data_loaders = fetch_data_loaders(**data_loader_args)
- dataset = list(data_loaders.values())[0].dataset
- dataset_name = dataset.__name__
- mapper = dataset.mapper
- else:
- self._mapper = EmnistMapper()
- dataset_name = "*"
- data_loaders = None
- return dataset_name, data_loaders, mapper
-
- def _configure_network(
- self, network_fn: Type[nn.Module], network_args: Optional[Dict]
- ) -> Tuple[Type[nn.Module], Dict]:
+ def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
# If no network arguments are given, load pretrained weights if they exist.
- if network_args is None:
- network, network_args = self.load_weights(network_fn)
+ if self._network_args is None:
+ self.load_weights(network_fn)
else:
- network = network_fn(**network_args)
- return network, network_args
+ self._network = network_fn(**self._network_args)
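Note the behavior this keeps from the old code: an explicit `network_args` dict builds a fresh network, while `network_args=None` falls through to `load_weights` below, which restores both the weights and the saved `network_args`. A short sketch, reusing the hypothetical names from the sketch above:

    # Fresh, randomly initialized network from explicit arguments.
    model = CharacterModel(network_fn=MLP, dataset=EmnistDataset,
                           network_args={"num_classes": 62},
                           dataset_args=dataset_args)

    # network_args=None: weights and network_args are restored from
    # weights/CharacterModel_EmnistDataset_MLP_weights.pt (per the name pattern above).
    pretrained = CharacterModel(network_fn=MLP, dataset=EmnistDataset,
                                dataset_args=dataset_args)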
- def _configure_criterion(
- self, criterion: Optional[Callable], criterion_args: Optional[Dict]
- ) -> Optional[Callable]:
+ def _configure_criterion(self) -> None:
"""Loads the criterion."""
- if criterion is not None:
- _criterion = criterion(**criterion_args)
- else:
- _criterion = None
- return _criterion
+ self._criterion = (
+ self._criterion(**self.criterion_args)
+ if self._criterion is not None
+ else None
+ )
- def _configure_optimizers(
- self,
- optimizer: Optional[Callable],
- optimizer_args: Optional[Dict],
- lr_scheduler: Optional[Callable],
- lr_scheduler_args: Optional[Dict],
- ) -> Tuple[Optional[Callable], Optional[Callable]]:
+ def _configure_optimizers(self) -> None:
"""Loads the optimizers."""
- if optimizer is not None:
- _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+ if self._optimizer is not None:
+ self._optimizer = self._optimizer(
+ self._network.parameters(), **self.optimizer_args
+ )
else:
- _optimizer = None
+ self._optimizer = None
- if _optimizer and lr_scheduler is not None:
- if "OneCycleLR" in str(lr_scheduler):
- lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
- _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args)
+ if self._optimizer and self._lr_scheduler is not None:
+ if "OneCycleLR" in str(self._lr_scheduler):
+ self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+ self._lr_scheduler = self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ )
else:
- _lr_scheduler = None
+ self._lr_scheduler = None
- return _optimizer, _lr_scheduler
+ if self.swa_args is not None:
+ self._swa_start = self.swa_args["start"]
+ self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+ self._swa_network = AveragedModel(self._network).to(self.device)
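The three SWA fields follow the standard `torch.optim.swa_utils` recipe: run the ordinary scheduler until the start epoch, then update the averaged model and step the SWA scheduler instead. A minimal sketch of the loop a trainer can run against these properties; `train_one_epoch` and `max_epochs` are hypothetical, and the `optimizer`/`lr_scheduler` properties are assumed from the surrounding class:

    import torch

    for epoch in range(max_epochs):
        train_one_epoch(model.network, model.optimizer, model.train_dataloader())
        if model.swa_start is not None and epoch >= model.swa_start:
            model.swa_network.update_parameters(model.network)
            model.swa_scheduler.step()
        elif model.lr_scheduler is not None:
            model.lr_scheduler.step()

    # Recompute batch-norm running statistics for the averaged weights.
    torch.optim.swa_utils.update_bn(model.train_dataloader(), model.swa_network)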
@property
- def __name__(self) -> str:
+ def name(self) -> str:
"""Returns the name of the model."""
return self._name
@@ -159,7 +240,7 @@ class Model(ABC):
return self._input_shape
@property
- def mapper(self) -> Dict:
+ def mapper(self) -> EmnistMapper:
"""Returns the mapper that maps between ints and chars."""
return self._mapper
@@ -202,13 +283,24 @@ class Model(ABC):
return self._lr_scheduler
@property
- def data_loaders(self) -> Optional[Dict]:
- """Dataloaders."""
- return self._data_loaders
+ def swa_scheduler(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging scheduler."""
+ return self._swa_scheduler
+
+ @property
+ def swa_start(self) -> Optional[int]:
+ """Returns the start epoch of stochastic weight averaging."""
+ return self._swa_start
@property
- def network(self) -> nn.Module:
+ def swa_network(self) -> Optional[nn.Module]:
+ """Returns the stochastic weight averaging network."""
+ return self._swa_network
+
+ @property
+ def network(self) -> nn.Module:
"""Neural network."""
+ # TODO: return the SWA network if available.
return self._network
@property
@@ -217,15 +309,27 @@ class Model(ABC):
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
- def summary(self) -> None:
+ def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Compute the loss."""
+ return self.criterion(output, targets)
+
+ def summary(
+ self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+ ) -> None:
"""Prints a summary of the network architecture."""
- device = re.sub("[^A-Za-z]+", "", self.device)
- if self._input_shape is not None:
+
+ if input_shape is not None:
+ summary(self._network, input_shape, depth=depth, device=self.device)
+ elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self._network, input_shape, device=device)
+ summary(self._network, input_shape, depth=depth, device=self.device)
else:
logger.warning("Could not print summary as input shape is not set.")
+ def to_device(self) -> None:
+ """Places the network on the device (GPU)."""
+ self._network.to(self._device)
+
def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
@@ -236,69 +340,67 @@ class Model(ABC):
if self._lr_scheduler is not None:
state["scheduler_state"] = self._lr_scheduler.state_dict()
+ if self._swa_network is not None:
+ state["swa_network"] = self._swa_network.state_dict()
+
return state
- def load_checkpoint(self, path: Path) -> int:
+ def load_from_checkpoint(self, checkpoint_path: Path) -> None:
"""Load a previously saved checkpoint.
Args:
- path (Path): Path to the experiment with the checkpoint.
-
- Returns:
- epoch (int): The last epoch when the checkpoint was created.
+ checkpoint_path (Path): Path to the checkpoint file.
"""
logger.debug("Loading checkpoint...")
- if not path.exists():
- logger.debug("File does not exist {str(path)}")
+ if not checkpoint_path.exists():
+ logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(path))
+ checkpoint = torch.load(str(checkpoint_path))
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs.
- # if self._lr_scheduler is not None:
- # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
-
- epoch = checkpoint["epoch"]
+ if self._lr_scheduler is not None:
+ # Does not work when loading from a previous checkpoint and trying to train beyond the last max epochs
+ # with OneCycleLR.
+ if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
- return epoch
+ if self._swa_network is not None:
+ self._swa_network.load_state_dict(checkpoint["swa_network"])
- def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
+ def save_checkpoint(
+ self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
"""Saves a checkpoint of the model.
Args:
+ checkpoint_path (Path): Path to the experiment directory where the checkpoint is saved.
is_best (bool): If it is the currently best model.
epoch (int): The epoch of the checkpoint.
val_metric (str): Validation metric.
- Raises:
- ValueError: If the self.model_dir is not set.
-
"""
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
state["network_args"] = self._network_args
- if self.model_dir is None:
- raise ValueError("Experiment directory is not set.")
-
- self.model_dir.mkdir(parents=True, exist_ok=True)
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
logger.debug("Saving checkpoint...")
- filepath = str(self.model_dir / "last.pt")
+ filepath = str(checkpoint_path / "last.pt")
torch.save(state, filepath)
if is_best:
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
+ shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
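With the directory now passed in explicitly instead of read from `self.model_dir`, a checkpoint round trip looks roughly like this; the path and metric name are illustrative:

    from pathlib import Path

    run_dir = Path("experiments/CharacterModel_EmnistDataset_MLP/run_01")
    model.save_checkpoint(run_dir, is_best=True, epoch=10, val_metric="val_accuracy")

    # load_from_checkpoint expects the checkpoint file itself, not the directory.
    model.load_from_checkpoint(run_dir / "last.pt")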
- def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
+ def load_weights(self, network_fn: Type[nn.Module]) -> None:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -308,13 +410,16 @@ class Model(ABC):
)
# Loading the state dict.
state_dict = torch.load(filename, map_location=torch.device(self._device))
- network_args = state_dict["network_args"]
+ self._network_args = state_dict["network_args"]
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- network = network_fn(**self._network_args)
- network.load_state_dict(weights)
- return network, network_args
+ self._network = network_fn(**self._network_args)
+ self._network.load_state_dict(weights)
+
+ if "swa_network" in state_dict:
+ self._swa_network = AveragedModel(self._network).to(self.device)
+ self._swa_network.load_state_dict(state_dict["swa_network"])
def save_weights(self, path: Path) -> None:
"""Save the network weights."""