summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-07-22 23:18:08 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-07-22 23:18:08 +0200
commitf473456c19558aaf8552df97a51d4e18cc69dfa8 (patch)
tree0d35ce2410ff623ba5fb433d616d95b67ecf7a98 /src/text_recognizer/models
parentad3bd52530f4800d4fb05dfef3354921f95513af (diff)
Working training loop and testing of trained CharacterModel.
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/__init__.py4
-rw-r--r--src/text_recognizer/models/base.py66
-rw-r--r--src/text_recognizer/models/character_model.py30
-rw-r--r--src/text_recognizer/models/metrics.py2
4 files changed, 69 insertions, 33 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index d265dcf..ff10a07 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,2 +1,6 @@
"""Model modules."""
+from .base import Model
from .character_model import CharacterModel
+from .metrics import accuracy
+
+__all__ = ["Model", "CharacterModel", "accuracy"]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 0cc531a..b78eacb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -1,9 +1,11 @@
"""Abstract Model class for PyTorch neural networks."""
from abc import ABC, abstractmethod
+from glob import glob
from pathlib import Path
+import re
import shutil
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
from loguru import logger
import torch
@@ -19,7 +21,7 @@ class Model(ABC):
def __init__(
self,
- network_fn: Callable,
+ network_fn: Type[nn.Module],
network_args: Dict,
data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
@@ -35,7 +37,7 @@ class Model(ABC):
"""Base class, to be inherited by model for specific type of data.
Args:
- network_fn (Callable): The PyTorch network.
+ network_fn (Type[nn.Module]): The PyTorch network.
network_args (Dict): Arguments for the network.
data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
data_loader_args (Optional[Dict]): Arguments for the DataLoader.
@@ -57,27 +59,29 @@ class Model(ABC):
self._data_loaders = data_loader(**data_loader_args)
dataset_name = self._data_loaders.__name__
else:
- dataset_name = ""
+ dataset_name = "*"
self._data_loaders = None
- self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+ self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
# Extract the input shape for the torchsummary.
- self._input_shape = network_args.pop("input_shape")
+ if isinstance(network_args["input_size"], int):
+ self._input_shape = (1,) + tuple([network_args["input_size"]])
+ else:
+ self._input_shape = (1,) + tuple(network_args["input_size"])
if metrics is not None:
self._metrics = metrics
# Set the device.
- if self.device is None:
- self._device = torch.device(
- "cuda:0" if torch.cuda.is_available() else "cpu"
- )
+ if device is None:
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self._device = device
# Load network.
- self._network = network_fn(**network_args)
+ self.network_args = network_args
+ self._network = network_fn(**self.network_args)
# To device.
self._network.to(self._device)
@@ -95,13 +99,29 @@ class Model(ABC):
# Set learning rate scheduler.
self._lr_scheduler = None
if lr_scheduler is not None:
+ # OneCycleLR needs the number of steps in an epoch as an input argument.
+ if "OneCycleLR" in str(lr_scheduler):
+ lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+ # Class mapping.
+ self._mapping = None
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the model."""
+ return self._name
+
@property
def input_shape(self) -> Tuple[int, ...]:
"""The input shape."""
return self._input_shape
+ @property
+ def mapping(self) -> Dict:
+ """Returns the class mapping."""
+ return self._mapping
+
def eval(self) -> None:
"""Sets the network to evaluation mode."""
self._network.eval()
@@ -149,13 +169,14 @@ class Model(ABC):
def weights_filename(self) -> str:
"""Filepath to the network weights."""
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
- return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt")
+ return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
def summary(self) -> None:
"""Prints a summary of the network architecture."""
- summary(self._network, self._input_shape, device=self.device)
+ device = re.sub("[^A-Za-z]+", "", self.device)
+ summary(self._network, self._input_shape, device=device)
- def _get_state(self) -> Dict:
+ def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
if self._optimizer is not None:
@@ -172,6 +193,7 @@ class Model(ABC):
epoch (int): The last epoch when the checkpoint was created.
"""
+ logger.debug("Loading checkpoint...")
if not path.exists():
logger.debug("File does not exist {str(path)}")
@@ -200,6 +222,7 @@ class Model(ABC):
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
+ state["network_args"] = self.network_args
path.mkdir(parents=True, exist_ok=True)
@@ -216,15 +239,18 @@ class Model(ABC):
def load_weights(self) -> None:
"""Load the network weights."""
logger.debug("Loading network weights.")
- weights = torch.load(self.weights_filename)["model_state"]
+ filename = glob(self.weights_filename)[0]
+ weights = torch.load(filename, map_location=torch.device(self._device))[
+ "model_state"
+ ]
self._network.load_state_dict(weights)
- def save_weights(self) -> None:
+ def save_weights(self, path: Path) -> None:
"""Save the network weights."""
- logger.debug("Saving network weights.")
- torch.save({"model_state": self._network.state_dict()}, self.weights_filename)
+ logger.debug("Saving the best network weights.")
+ shutil.copyfile(str(path / "best.pt"), self.weights_filename)
@abstractmethod
- def mapping(self) -> Dict:
- """Mapping from network output to class."""
+ def load_mapping(self) -> None:
+ """Loads class mapping from network output to character."""
...
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index fd69bf2..527fc7d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,5 +1,5 @@
"""Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
import numpy as np
import torch
@@ -8,7 +8,6 @@ from torchvision.transforms import ToTensor
from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
from text_recognizer.models.base import Model
-from text_recognizer.networks.mlp import mlp
class CharacterModel(Model):
@@ -16,8 +15,9 @@ class CharacterModel(Model):
def __init__(
self,
- network_fn: Callable,
+ network_fn: Type[nn.Module],
network_args: Dict,
+ data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -33,6 +33,7 @@ class CharacterModel(Model):
super().__init__(
network_fn,
network_args,
+ data_loader,
data_loader_args,
metrics,
criterion,
@@ -43,13 +44,13 @@ class CharacterModel(Model):
lr_scheduler_args,
device,
)
- self.emnist_mapping = self.mapping()
- self.eval()
+ self.load_mapping()
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
- def mapping(self) -> Dict[int, str]:
+ def load_mapping(self) -> None:
"""Mapping between integers and classes."""
- mapping = load_emnist_mapping()
- return mapping
+ self._mapping = load_emnist_mapping()
def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
"""Character prediction on an image.
@@ -61,15 +62,20 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+
if image.dtype == np.uint8:
image = (image / 255).astype(np.float32)
# Conver to Pytorch Tensor.
- image = ToTensor(image)
+ image = self.tensor_transform(image)
+
+ with torch.no_grad():
+ logits = self.network(image)
+
+ prediction = self.softmax(logits.data.squeeze())
- prediction = self.network(image)
- index = torch.argmax(prediction, dim=1)
+ index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
- predicted_character = self.emnist_mapping[index]
+ predicted_character = self._mapping[index]
return predicted_character, confidence_of_prediction
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index e2a30a9..ac8d68e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -3,7 +3,7 @@
import torch
-def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float:
+def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
"""Computes the accuracy.
Args: