author    aktersnurra <gustaf.rydholm@gmail.com>  2020-09-08 23:14:23 +0200
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-09-08 23:14:23 +0200
commit    e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree      70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/models
parent    fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--  src/text_recognizer/models/__init__.py           5
-rw-r--r--  src/text_recognizer/models/base.py             331
-rw-r--r--  src/text_recognizer/models/character_model.py   20
-rw-r--r--  src/text_recognizer/models/line_ctc_model.py   105
-rw-r--r--  src/text_recognizer/models/metrics.py           80
5 files changed, 417 insertions(+), 124 deletions(-)
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index ff10a07..a3cfc15 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,6 +1,7 @@
"""Model modules."""
from .base import Model
from .character_model import CharacterModel
-from .metrics import accuracy
+from .line_ctc_model import LineCTCModel
+from .metrics import accuracy, cer, wer
-__all__ = ["Model", "CharacterModel", "accuracy"]
+__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 3a84a11..153e19a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from glob import glob
+import importlib
from pathlib import Path
import re
import shutil
@@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type
from loguru import logger
import torch
from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
from torchsummary import summary
-from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
+from text_recognizer.datasets import EmnistMapper
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -23,8 +27,9 @@ class Model(ABC):
def __init__(
self,
network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
network_args: Optional[Dict] = None,
- data_loader_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
criterion_args: Optional[Dict] = None,
@@ -32,14 +37,16 @@ class Model(ABC):
optimizer_args: Optional[Dict] = None,
lr_scheduler: Optional[Callable] = None,
lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
"""Base class, to be inherited by model for specific type of data.
Args:
network_fn (Type[nn.Module]): The PyTorch network.
+ dataset (Type[Dataset]): A dataset class.
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
- data_loader_args (Optional[Dict]): Arguments for the DataLoader.
+ dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
@@ -49,107 +56,181 @@ class Model(ABC):
lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
None.
+ swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+ None.
device (Optional[str]): Name of the device to train on. Defaults to None.
"""
+ # Has to be set in subclass.
+ self._mapper = None
- # Configure data loaders and dataset info.
- dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
- data_loader_args
- )
- self._input_shape = self._mapper.input_shape
+ # Placeholder.
+ self._input_shape = None
+
+ self.dataset = dataset
+ self.dataset_args = dataset_args
+
+ # Placeholders for datasets.
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
- self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+ # Stochastic Weight Averaging placeholders.
+ self.swa_args = swa_args
+ self._swa_start = None
+ self._swa_scheduler = None
+ self._swa_network = None
- if metrics is not None:
- self._metrics = metrics
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flag for configured model.
+ self.is_configured = False
+ self.data_prepared = False
+
+ # Flag for stopping training.
+ self.stop_training = False
+
+ self._name = (
+ f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
+ )
+
+ self._metrics = metrics if metrics is not None else None
# Set the device.
- if device is None:
- self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- else:
- self._device = device
+ self._device = (
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device is None
+ else device
+ )
# Configure network.
- self._network, self._network_args = self._configure_network(
- network_fn, network_args
- )
+ self._network = None
+ self._network_args = network_args
+ self._configure_network(network_fn)
- # To device.
- self._network.to(self._device)
+ # Place network on device (GPU).
+ self.to_device()
+
+ # Loss and Optimizer placeholders for before loading.
+ self._criterion = criterion
+ self.criterion_args = criterion_args
+
+ self._optimizer = optimizer
+ self.optimizer_args = optimizer_args
+
+ self._lr_scheduler = lr_scheduler
+ self.lr_scheduler_args = lr_scheduler_args
+
+ def configure_model(self) -> None:
+ """Configures criterion and optimizers."""
+ if not self.is_configured:
+ self._configure_criterion()
+ self._configure_optimizers()
+
+ # Prints a summary of the network in terminal.
+ self.summary()
+
+ # Set this flag to true to prevent the model from configuring again.
+ self.is_configured = True
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ # TODO add downloading.
+ if not self.data_prepared:
+ # Load train dataset.
+ train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+
+ # Set input shape.
+ self._input_shape = train_dataset.input_shape
+
+ # Split train dataset into a training and validation partition.
+ dataset_len = len(train_dataset)
+ train_len = int(
+ self.dataset_args["train_args"]["train_fraction"] * dataset_len
+ )
+ val_len = dataset_len - train_len
+ self.train_dataset, self.val_dataset = random_split(
+ train_dataset, lengths=[train_len, val_len]
+ )
+
+ # Load test dataset.
+ self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+
+ # Set the flag to true to prevent the data from being loaded again.
+ self.data_prepared = True
- # Configure training objects.
- self._criterion = self._configure_criterion(criterion, criterion_args)
- self._optimizer, self._lr_scheduler = self._configure_optimizers(
- optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+ def train_dataloader(self) -> DataLoader:
+ """Returns data loader for training set."""
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
)
- # Experiment directory.
- self.model_dir = None
+ def val_dataloader(self) -> DataLoader:
+ """Returns data loader for validation set."""
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
- # Flag for stopping training.
- self.stop_training = False
+ def test_dataloader(self) -> DataLoader:
+ """Returns data loader for test set."""
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=False,
+ pin_memory=True,
+ )
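The three dataloader methods above, together with prepare_data, fix the expected layout of dataset_args: an "args" dict forwarded to the Dataset constructor and a "train_args" dict with the split and loader settings. A minimal sketch of that layout, with the keys mirroring the lookups in the code and the concrete values as illustrative assumptions:

    # Hypothetical dataset_args; keys follow prepare_data and the
    # *_dataloader methods, values are placeholders.
    dataset_args = {
        "args": {},  # dataset-specific kwargs, used as self.dataset(train=..., **args)
        "train_args": {
            "train_fraction": 0.8,  # share of the training data kept for training
            "batch_size": 128,
            "num_workers": 4,
        },
    }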
- def _configure_data_loader(
- self, data_loader_args: Optional[Dict]
- ) -> Tuple[str, Dict, EmnistMapper]:
- """Loads data loader, dataset name, and dataset mapper."""
- if data_loader_args is not None:
- data_loaders = fetch_data_loaders(**data_loader_args)
- dataset = list(data_loaders.values())[0].dataset
- dataset_name = dataset.__name__
- mapper = dataset.mapper
- else:
- self._mapper = EmnistMapper()
- dataset_name = "*"
- data_loaders = None
- return dataset_name, data_loaders, mapper
-
- def _configure_network(
- self, network_fn: Type[nn.Module], network_args: Optional[Dict]
- ) -> Tuple[Type[nn.Module], Dict]:
+ def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
# If no network arguments are given, load pretrained weights if they exist.
- if network_args is None:
- network, network_args = self.load_weights(network_fn)
+ if self._network_args is None:
+ self.load_weights(network_fn)
else:
- network = network_fn(**network_args)
- return network, network_args
+ self._network = network_fn(**self._network_args)
- def _configure_criterion(
- self, criterion: Optional[Callable], criterion_args: Optional[Dict]
- ) -> Optional[Callable]:
+ def _configure_criterion(self) -> None:
"""Loads the criterion."""
- if criterion is not None:
- _criterion = criterion(**criterion_args)
- else:
- _criterion = None
- return _criterion
+ self._criterion = (
+ self._criterion(**self.criterion_args)
+ if self._criterion is not None
+ else None
+ )
- def _configure_optimizers(
- self,
- optimizer: Optional[Callable],
- optimizer_args: Optional[Dict],
- lr_scheduler: Optional[Callable],
- lr_scheduler_args: Optional[Dict],
- ) -> Tuple[Optional[Callable], Optional[Callable]]:
+ def _configure_optimizers(self) -> None:
"""Loads the optimizers."""
- if optimizer is not None:
- _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+ if self._optimizer is not None:
+ self._optimizer = self._optimizer(
+ self._network.parameters(), **self.optimizer_args
+ )
else:
- _optimizer = None
+ self._optimizer = None
- if _optimizer and lr_scheduler is not None:
- if "OneCycleLR" in str(lr_scheduler):
- lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
- _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args)
+ if self._optimizer and self._lr_scheduler is not None:
+ if "OneCycleLR" in str(self._lr_scheduler):
+ self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+ self._lr_scheduler = self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ )
else:
- _lr_scheduler = None
+ self._lr_scheduler = None
- return _optimizer, _lr_scheduler
+ if self.swa_args is not None:
+ self._swa_start = self.swa_args["start"]
+ self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+ self._swa_network = AveragedModel(self._network).to(self.device)
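For context, the SWA objects configured here follow the standard torch.optim.swa_utils pattern. A minimal sketch of how a training loop would consume them, assuming a hypothetical train_one_epoch helper and max_epochs count (the loop itself lives in the trainer, not in this class):

    import torch

    for epoch in range(max_epochs):
        train_one_epoch(model)  # hypothetical: one pass over model.train_dataloader()
        if epoch >= model.swa_start:
            # Fold the current weights into the running average and anneal the LR.
            model.swa_network.update_parameters(model.network)
            model.swa_scheduler.step()
        elif model.lr_scheduler is not None:
            model.lr_scheduler.step()

    # Recompute batch-norm statistics for the averaged weights before evaluating.
    torch.optim.swa_utils.update_bn(model.train_dataloader(), model.swa_network)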
@property
- def __name__(self) -> str:
+ def name(self) -> str:
"""Returns the name of the model."""
return self._name
@@ -159,7 +240,7 @@ class Model(ABC):
return self._input_shape
@property
- def mapper(self) -> Dict:
+ def mapper(self) -> EmnistMapper:
"""Returns the mapper that maps between ints and chars."""
return self._mapper
@@ -202,13 +283,24 @@ class Model(ABC):
return self._lr_scheduler
@property
- def data_loaders(self) -> Optional[Dict]:
- """Dataloaders."""
- return self._data_loaders
+ def swa_scheduler(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging scheduler."""
+ return self._swa_scheduler
+
+ @property
+ def swa_start(self) -> Optional[int]:
+ """Returns the start epoch of stochastic weight averaging."""
+ return self._swa_start
@property
- def network(self) -> nn.Module:
+ def swa_network(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging network."""
+ return self._swa_network
+
+ @property
+ def network(self) -> Type[nn.Module]:
"""Neural network."""
+ # Note: the averaged SWA network is exposed separately via swa_network.
return self._network
@property
@@ -217,15 +309,27 @@ class Model(ABC):
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
- def summary(self) -> None:
+ def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Compute the loss."""
+ return self.criterion(output, targets)
+
+ def summary(
+ self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+ ) -> None:
"""Prints a summary of the network architecture."""
- device = re.sub("[^A-Za-z]+", "", self.device)
- if self._input_shape is not None:
+
+ if input_shape is not None:
+ summary(self._network, input_shape, depth=depth, device=self.device)
+ elif self._input_shape is not None:
input_shape = (1,) + tuple(self._input_shape)
- summary(self._network, input_shape, device=device)
+ summary(self._network, input_shape, depth=depth, device=self.device)
else:
logger.warning("Could not print summary as input shape is not set.")
+ def to_device(self) -> None:
+ """Places the network on the device (GPU)."""
+ self._network.to(self._device)
+
def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
@@ -236,69 +340,67 @@ class Model(ABC):
if self._lr_scheduler is not None:
state["scheduler_state"] = self._lr_scheduler.state_dict()
+ if self._swa_network is not None:
+ state["swa_network"] = self._swa_network.state_dict()
+
return state
- def load_checkpoint(self, path: Path) -> int:
+ def load_from_checkpoint(self, checkpoint_path: Path) -> None:
"""Load a previously saved checkpoint.
Args:
- path (Path): Path to the experiment with the checkpoint.
-
- Returns:
- epoch (int): The last epoch when the checkpoint was created.
+ checkpoint_path (Path): Path to the checkpoint file.
"""
logger.debug("Loading checkpoint...")
- if not path.exists():
- logger.debug("File does not exist {str(path)}")
+ if not checkpoint_path.exists():
+ logger.debug("File does not exist {str(checkpoint_path)}")
- checkpoint = torch.load(str(path))
+ checkpoint = torch.load(str(checkpoint_path))
self._network.load_state_dict(checkpoint["model_state"])
if self._optimizer is not None:
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs.
- # if self._lr_scheduler is not None:
- # self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
-
- epoch = checkpoint["epoch"]
+ if self._lr_scheduler is not None:
+ # Does not work when loading from a previous checkpoint and trying to train beyond the last max epochs
+ # with OneCycleLR.
+ if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
- return epoch
+ if self._swa_network is not None:
+ self._swa_network.load_state_dict(checkpoint["swa_network"])
- def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
+ def save_checkpoint(
+ self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
"""Saves a checkpoint of the model.
Args:
+ checkpoint_path (Path): Directory where the checkpoint is saved.
is_best (bool): If it is the currently best model.
epoch (int): The epoch of the checkpoint.
val_metric (str): Validation metric.
- Raises:
- ValueError: If the self.model_dir is not set.
-
"""
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
state["network_args"] = self._network_args
- if self.model_dir is None:
- raise ValueError("Experiment directory is not set.")
-
- self.model_dir.mkdir(parents=True, exist_ok=True)
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
logger.debug("Saving checkpoint...")
- filepath = str(self.model_dir / "last.pt")
+ filepath = str(checkpoint_path / "last.pt")
torch.save(state, filepath)
if is_best:
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
+ shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
- def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
+ def load_weights(self, network_fn: Type[nn.Module]) -> None:
"""Load the network weights."""
logger.debug("Loading network with pretrained weights.")
filename = glob(self.weights_filename)[0]
@@ -308,13 +410,16 @@ class Model(ABC):
)
# Loading the state dictionary.
state_dict = torch.load(filename, map_location=torch.device(self._device))
- network_args = state_dict["network_args"]
+ self._network_args = state_dict["network_args"]
weights = state_dict["model_state"]
# Initializes the network with trained weights.
- network = network_fn(**self._network_args)
- network.load_state_dict(weights)
- return network, network_args
+ self._network = network_fn(**self._network_args)
+ self._network.load_state_dict(weights)
+
+ if "swa_network" in state_dict:
+ self._swa_network = AveragedModel(self._network).to(self.device)
+ self._swa_network.load_state_dict(state_dict["swa_network"])
def save_weights(self, path: Path) -> None:
"""Save the network weights."""
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0fd7afd..64ba693 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
import torch
from torch import nn
+from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
+from text_recognizer.datasets import EmnistMapper
from text_recognizer.models.base import Model
@@ -15,8 +17,9 @@ class CharacterModel(Model):
def __init__(
self,
network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
network_args: Optional[Dict] = None,
- data_loader_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
criterion_args: Optional[Dict] = None,
@@ -24,14 +27,16 @@ class CharacterModel(Model):
optimizer_args: Optional[Dict] = None,
lr_scheduler: Optional[Callable] = None,
lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
"""Initializes the CharacterModel."""
super().__init__(
network_fn,
+ dataset,
network_args,
- data_loader_args,
+ dataset_args,
metrics,
criterion,
criterion_args,
@@ -39,8 +44,11 @@ class CharacterModel(Model):
optimizer_args,
lr_scheduler,
lr_scheduler_args,
+ swa_args,
device,
)
+ if self._mapper is None:
+ self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
@@ -67,9 +75,13 @@ class CharacterModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- logits = self.network(image)
+ logits = (
+ self.swa_network(image)
+ if self.swa_network is not None
+ else self.network(image)
+ )
- prediction = self.softmax(logits.data.squeeze())
+ prediction = self.softmax(logits.squeeze(0))
index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
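The prediction path above reduces to: softmax over the class logits, argmax for the label, and the selected probability as the confidence. A self-contained sketch of that pattern with a dummy logits tensor (the class count of 80 is an assumption matching the blank index used in line_ctc_model.py below):

    import torch
    from torch import nn

    softmax = nn.Softmax(dim=0)
    logits = torch.randn(80)  # dummy class scores for a single image

    prediction = softmax(logits)
    index = int(torch.argmax(prediction, dim=0))
    confidence = prediction[index].item()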
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
new file mode 100644
index 0000000..97308a7
--- /dev/null
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -0,0 +1,105 @@
+"""Defines the LineCTCModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class LineCTCModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ if self._mapper is None:
+ self._mapper = EmnistMapper()
+ self.tensor_transform = ToTensor()
+
+ def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+ target_lengths = torch.full(
+ size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ return self.criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = (
+ self.swa_network(image)
+ if self.swa_network is not None
+ else self.network(image)
+ )
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=79,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = torch.exp(log_probs.sum()).item()
+
+ return predicted_characters, confidence_of_prediction
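loss_fn above assumes the (T, N, C) layout that nn.CTCLoss expects: output.shape[0] is the sequence length T, used as every input length, and output.shape[1] is the batch size N, while every target row has the fixed length targets.shape[1]. A shape-only sketch with dummy tensors (the sizes are illustrative; the blank index matches the one used in predict_on_image):

    import torch
    from torch import nn

    T, N, C, S = 97, 4, 80, 32  # sequence length, batch size, classes, target length

    log_probs = torch.randn(T, N, C).log_softmax(dim=2)   # stand-in network output
    targets = torch.randint(low=1, high=79, size=(N, S))  # padded label sequences

    input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    target_lengths = torch.full(size=(N,), fill_value=S, dtype=torch.long)

    criterion = nn.CTCLoss(blank=79)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)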
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index ac8d68e..6a26216 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -1,19 +1,89 @@
"""Utility functions for models."""
-
+import Levenshtein as Lev
import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
-def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
+def accuracy(outputs: Tensor, labels: Tensor) -> float:
"""Computes the accuracy.
Args:
- outputs (torch.Tensor): The output from the network.
- labels (torch.Tensor): Ground truth labels.
+ outputs (Tensor): The output from the network.
+ labels (Tensor): Ground truth labels.
Returns:
float: The accuracy for the batch.
"""
_, predicted = torch.max(outputs.data, dim=1)
- acc = (predicted == labels).sum().item() / labels.shape[0]
+ acc = (predicted == labels).sum().float() / labels.shape[0]
+ acc = acc.item()
return acc
+
+
+def cer(outputs: Tensor, targets: Tensor) -> float:
+ """Computes the character error rate.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ targets (Tensor): Ground truth labels.
+
+ Returns:
+ float: The cer for the batch.
+
+ """
+ target_lengths = torch.full(
+ size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ decoded_predictions, decoded_targets = greedy_decoder(
+ outputs, targets, target_lengths
+ )
+
+ lev_dist = 0
+
+ for prediction, target in zip(decoded_predictions, decoded_targets):
+ prediction = "".join(prediction)
+ target = "".join(target)
+ prediction, target = (
+ prediction.replace(" ", ""),
+ target.replace(" ", ""),
+ )
+ lev_dist += Lev.distance(prediction, target)
+ return lev_dist / len(decoded_predictions)
+
+
+def wer(outputs: Tensor, targets: Tensor) -> float:
+ """Computes the Word error rate.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ targets (Tensor): Ground truth labels.
+
+ Returns:
+ float: The wer for the batch.
+
+ """
+ target_lengths = torch.full(
+ size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ decoded_predictions, decoded_targets = greedy_decoder(
+ outputs, targets, target_lengths
+ )
+
+ lev_dist = 0
+
+ for prediction, target in zip(decoded_predictions, decoded_targets):
+ prediction = "".join(prediction)
+ target = "".join(target)
+
+ b = set(prediction.split() + target.split())
+ word2char = dict(zip(b, range(len(b))))
+
+ w1 = [chr(word2char[w]) for w in prediction.split()]
+ w2 = [chr(word2char[w]) for w in target.split()]
+
+ lev_dist += Lev.distance("".join(w1), "".join(w2))
+
+ return lev_dist / len(decoded_predictions)
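The wer function above relies on a small trick: Levenshtein distance operates on sequences of characters, so each distinct word is first mapped to a unique placeholder character, turning a character-level edit distance into a word-level one. A minimal worked example (the sentences are illustrative):

    import Levenshtein as Lev

    prediction = "the quick brown fox"
    target = "the quick brown foxes"

    # Map every distinct word to a unique (arbitrary) single character.
    vocab = set(prediction.split() + target.split())
    word2char = dict(zip(vocab, range(len(vocab))))

    w1 = "".join(chr(word2char[w]) for w in prediction.split())
    w2 = "".join(chr(word2char[w]) for w in target.split())

    assert Lev.distance(w1, w2) == 1  # one word substitution: fox -> foxes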