From e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Tue, 8 Sep 2020 23:14:23 +0200
Subject: IAM datasets implemented.

---
 src/text_recognizer/models/__init__.py        |   5 +-
 src/text_recognizer/models/base.py            | 331 +++++++++++++++++---------
 src/text_recognizer/models/character_model.py |  20 +-
 src/text_recognizer/models/line_ctc_model.py  | 105 ++++++++
 src/text_recognizer/models/metrics.py         |  80 ++++++-
 5 files changed, 417 insertions(+), 124 deletions(-)
 create mode 100644 src/text_recognizer/models/line_ctc_model.py

(limited to 'src/text_recognizer/models')

diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index ff10a07..a3cfc15 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,6 +1,7 @@
 """Model modules."""
 from .base import Model
 from .character_model import CharacterModel
-from .metrics import accuracy
+from .line_ctc_model import LineCTCModel
+from .metrics import accuracy, cer, wer
 
-__all__ = ["Model", "CharacterModel", "accuracy"]
+__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 3a84a11..153e19a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 from glob import glob
+import importlib
 from pathlib import Path
 import re
 import shutil
@@ -10,9 +11,12 @@ from typing import Callable, Dict, Optional, Tuple, Type
 from loguru import logger
 import torch
 from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
 from torchsummary import summary
 
-from text_recognizer.datasets import EmnistMapper, fetch_data_loaders
+from text_recognizer.datasets import EmnistMapper
 
 WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
 
@@ -23,8 +27,9 @@ class Model(ABC):
     def __init__(
         self,
         network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
         network_args: Optional[Dict] = None,
-        data_loader_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
         criterion_args: Optional[Dict] = None,
@@ -32,14 +37,16 @@ class Model(ABC):
         optimizer_args: Optional[Dict] = None,
         lr_scheduler: Optional[Callable] = None,
         lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
         device: Optional[str] = None,
     ) -> None:
         """Base class, to be inherited by model for specific type of data.
 
         Args:
             network_fn (Type[nn.Module]): The PyTorch network.
+            dataset (Type[Dataset]): A dataset class.
             network_args (Optional[Dict]): Arguments for the network. Defaults to None.
-            data_loader_args (Optional[Dict]):  Arguments for the DataLoader.
+            dataset_args (Optional[Dict]):  Arguments for the dataset.
             metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
             criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
                 Defaults to None.
@@ -49,107 +56,181 @@ class Model(ABC):
             lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
             lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
                 None.
+            swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+                None.
             device (Optional[str]): Name of the device to train on. Defaults to None.
 
         """
+        # Must be set in the subclass.
+        self._mapper = None
 
-        # Configure data loaders and dataset info.
-        dataset_name, self._data_loaders, self._mapper = self._configure_data_loader(
-            data_loader_args
-        )
-        self._input_shape = self._mapper.input_shape
+        # Placeholder.
+        self._input_shape = None
+
+        self.dataset = dataset
+        self.dataset_args = dataset_args
+
+        # Placeholders for datasets.
+        self.train_dataset = None
+        self.val_dataset = None
+        self.test_dataset = None
 
-        self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+        # Stochastic Weight Averaging placeholders.
+        self.swa_args = swa_args
+        self._swa_start = None
+        self._swa_scheduler = None
+        self._swa_network = None
 
-        if metrics is not None:
-            self._metrics = metrics
+        # Experiment directory.
+        self.model_dir = None
+
+        # Flags for configured model and prepared data.
+        self.is_configured = False
+        self.data_prepared = False
+
+        # Flag for stopping training.
+        self.stop_training = False
+
+        self._name = (
+            f"{self.__class__.__name__}_{dataset.__name__}_{network_fn.__name__}"
+        )
+
+        self._metrics = metrics
 
         # Set the device.
-        if device is None:
-            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self._device = device
+        self._device = (
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if device is None
+            else device
+        )
 
         # Configure network.
-        self._network, self._network_args = self._configure_network(
-            network_fn, network_args
-        )
+        self._network = None
+        self._network_args = network_args
+        self._configure_network(network_fn)
 
-        # To device.
-        self._network.to(self._device)
+        # Place the network on the device (GPU if available).
+        self.to_device()
+
+        # Criterion and optimizer are stored unconfigured; configure_model() sets them up.
+        self._criterion = criterion
+        self.criterion_args = criterion_args
+
+        self._optimizer = optimizer
+        self.optimizer_args = optimizer_args
+
+        self._lr_scheduler = lr_scheduler
+        self.lr_scheduler_args = lr_scheduler_args
+
+    def configure_model(self) -> None:
+        """Configures criterion and optimizers."""
+        if not self.is_configured:
+            self._configure_criterion()
+            self._configure_optimizers()
+
+            # Prints a summary of the network in terminal.
+            self.summary()
+
+            # Set this flag to true to prevent the model from being configured again.
+            self.is_configured = True
+
+    def prepare_data(self) -> None:
+        """Prepare data for training."""
+        # TODO add downloading.
+        if not self.data_prepared:
+            # Load train dataset.
+            train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+
+            # Set input shape.
+            self._input_shape = train_dataset.input_shape
+
+            # Split train dataset into a training and validation partition.
+            dataset_len = len(train_dataset)
+            train_len = int(
+                self.dataset_args["train_args"]["train_fraction"] * dataset_len
+            )
+            val_len = dataset_len - train_len
+            self.train_dataset, self.val_dataset = random_split(
+                train_dataset, lengths=[train_len, val_len]
+            )
+
+            # Load test dataset.
+            self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+
+            # Set the flag to true to prevent the data from being loaded again.
+            self.data_prepared = True
 
-        # Configure training objects.
-        self._criterion = self._configure_criterion(criterion, criterion_args)
-        self._optimizer, self._lr_scheduler = self._configure_optimizers(
-            optimizer, optimizer_args, lr_scheduler, lr_scheduler_args
+    def train_dataloader(self) -> DataLoader:
+        """Returns data loader for training set."""
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.dataset_args["train_args"]["batch_size"],
+            num_workers=self.dataset_args["train_args"]["num_workers"],
+            shuffle=True,
+            pin_memory=True,
         )
 
-        # Experiment directory.
-        self.model_dir = None
+    def val_dataloader(self) -> DataLoader:
+        """Returns data loader for validation set."""
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.dataset_args["train_args"]["batch_size"],
+            num_workers=self.dataset_args["train_args"]["num_workers"],
+            shuffle=True,
+            pin_memory=True,
+        )
 
-        # Flag for stopping training.
-        self.stop_training = False
+    def test_dataloader(self) -> DataLoader:
+        """Returns data loader for test set."""
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.dataset_args["train_args"]["batch_size"],
+            num_workers=self.dataset_args["train_args"]["num_workers"],
+            shuffle=False,
+            pin_memory=True,
+        )
 
-    def _configure_data_loader(
-        self, data_loader_args: Optional[Dict]
-    ) -> Tuple[str, Dict, EmnistMapper]:
-        """Loads data loader, dataset name, and dataset mapper."""
-        if data_loader_args is not None:
-            data_loaders = fetch_data_loaders(**data_loader_args)
-            dataset = list(data_loaders.values())[0].dataset
-            dataset_name = dataset.__name__
-            mapper = dataset.mapper
-        else:
-            self._mapper = EmnistMapper()
-            dataset_name = "*"
-            data_loaders = None
-        return dataset_name, data_loaders, mapper
-
-    def _configure_network(
-        self, network_fn: Type[nn.Module], network_args: Optional[Dict]
-    ) -> Tuple[Type[nn.Module], Dict]:
+    def _configure_network(self, network_fn: Type[nn.Module]) -> None:
         """Loads the network."""
         # If no network arguments are given, load pretrained weights if they exist.
-        if network_args is None:
-            network, network_args = self.load_weights(network_fn)
+        if self._network_args is None:
+            self.load_weights(network_fn)
         else:
-            network = network_fn(**network_args)
-        return network, network_args
+            self._network = network_fn(**self._network_args)
 
-    def _configure_criterion(
-        self, criterion: Optional[Callable], criterion_args: Optional[Dict]
-    ) -> Optional[Callable]:
+    def _configure_criterion(self) -> None:
         """Loads the criterion."""
-        if criterion is not None:
-            _criterion = criterion(**criterion_args)
-        else:
-            _criterion = None
-        return _criterion
+        self._criterion = (
+            self._criterion(**self.criterion_args)
+            if self._criterion is not None
+            else None
+        )
 
-    def _configure_optimizers(
-        self,
-        optimizer: Optional[Callable],
-        optimizer_args: Optional[Dict],
-        lr_scheduler: Optional[Callable],
-        lr_scheduler_args: Optional[Dict],
-    ) -> Tuple[Optional[Callable], Optional[Callable]]:
+    def _configure_optimizers(self) -> None:
         """Loads the optimizers."""
-        if optimizer is not None:
-            _optimizer = optimizer(self._network.parameters(), **optimizer_args)
+        if self._optimizer is not None:
+            self._optimizer = self._optimizer(
+                self._network.parameters(), **self.optimizer_args
+            )
         else:
-            _optimizer = None
+            self._optimizer = None
 
-        if _optimizer and lr_scheduler is not None:
-            if "OneCycleLR" in str(lr_scheduler):
-                lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"])
-            _lr_scheduler = lr_scheduler(_optimizer, **lr_scheduler_args)
+        if self._optimizer and self._lr_scheduler is not None:
+            if "OneCycleLR" in str(self._lr_scheduler):
+                self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+            self._lr_scheduler = self._lr_scheduler(
+                self._optimizer, **self.lr_scheduler_args
+            )
         else:
-            _lr_scheduler = None
+            self._lr_scheduler = None
 
-        return _optimizer, _lr_scheduler
+        if self.swa_args is not None:
+            self._swa_start = self.swa_args["start"]
+            self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+            self._swa_network = AveragedModel(self._network).to(self.device)
 
     @property
-    def __name__(self) -> str:
+    def name(self) -> str:
         """Returns the name of the model."""
         return self._name
 
@@ -159,7 +240,7 @@ class Model(ABC):
         return self._input_shape
 
     @property
-    def mapper(self) -> Dict:
+    def mapper(self) -> EmnistMapper:
         """Returns the mapper that maps between ints and chars."""
         return self._mapper
 
@@ -202,13 +283,24 @@ class Model(ABC):
         return self._lr_scheduler
 
     @property
-    def data_loaders(self) -> Optional[Dict]:
-        """Dataloaders."""
-        return self._data_loaders
+    def swa_scheduler(self) -> Optional[Callable]:
+        """Returns the stochastic weight averaging scheduler."""
+        return self._swa_scheduler
+
+    @property
+    def swa_start(self) -> Optional[int]:
+        """Returns the start epoch of stochastic weight averaging."""
+        return self._swa_start
 
     @property
-    def network(self) -> nn.Module:
+    def swa_network(self) -> Optional[nn.Module]:
+        """Returns the stochastic weight averaging network."""
+        return self._swa_network
+
+    @property
+    def network(self) -> nn.Module:
         """Neural network."""
+        # Note: returns the base network; the SWA network is exposed via swa_network.
         return self._network
 
     @property
@@ -217,15 +309,27 @@ class Model(ABC):
         WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
         return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
 
-    def summary(self) -> None:
+    def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+        """Compute the loss."""
+        return self.criterion(output, targets)
+
+    def summary(
+        self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+    ) -> None:
         """Prints a summary of the network architecture."""
-        device = re.sub("[^A-Za-z]+", "", self.device)
-        if self._input_shape is not None:
+
+        if input_shape is not None:
+            summary(self._network, input_shape, depth=depth, device=self.device)
+        elif self._input_shape is not None:
             input_shape = (1,) + tuple(self._input_shape)
-            summary(self._network, input_shape, device=device)
+            summary(self._network, input_shape, depth=depth, device=self.device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
 
+    def to_device(self) -> None:
+        """Places the network on the device (GPU)."""
+        self._network.to(self._device)
+
     def _get_state_dict(self) -> Dict:
         """Get the state dict of the model."""
         state = {"model_state": self._network.state_dict()}
@@ -236,69 +340,67 @@ class Model(ABC):
         if self._lr_scheduler is not None:
             state["scheduler_state"] = self._lr_scheduler.state_dict()
 
+        if self._swa_network is not None:
+            state["swa_network"] = self._swa_network.state_dict()
+
         return state
 
-    def load_checkpoint(self, path: Path) -> int:
+    def load_from_checkpoint(self, checkpoint_path: Path) -> None:
         """Load a previously saved checkpoint.
 
         Args:
-            path (Path): Path to the experiment with the checkpoint.
-
-        Returns:
-            epoch (int): The last epoch when the checkpoint was created.
+            checkpoint_path (Path): Path to the checkpoint file.
 
         """
         logger.debug("Loading checkpoint...")
-        if not path.exists():
-            logger.debug("File does not exist {str(path)}")
+        if not checkpoint_path.exists():
+            logger.debug(f"File does not exist {str(checkpoint_path)}")
 
-        checkpoint = torch.load(str(path))
+        checkpoint = torch.load(str(checkpoint_path))
         self._network.load_state_dict(checkpoint["model_state"])
 
         if self._optimizer is not None:
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
-        # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs.
-        # if self._lr_scheduler is not None:
-        #     self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
-
-        epoch = checkpoint["epoch"]
+        if self._lr_scheduler is not None:
+            # Loading the scheduler state does not work with OneCycleLR when resuming
+            # training beyond the original max epochs.
+            if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
+                self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
 
-        return epoch
+        if self._swa_network is not None:
+            self._swa_network.load_state_dict(checkpoint["swa_network"])
 
-    def save_checkpoint(self, is_best: bool, epoch: int, val_metric: str) -> None:
+    def save_checkpoint(
+        self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+    ) -> None:
         """Saves a checkpoint of the model.
 
         Args:
+            checkpoint_path (Path): Directory in which the checkpoint is saved.
             is_best (bool): If it is the currently best model.
             epoch (int): The epoch of the checkpoint.
             val_metric (str): Validation metric.
 
-        Raises:
-            ValueError: If the self.model_dir is not set.
-
         """
         state = self._get_state_dict()
         state["is_best"] = is_best
         state["epoch"] = epoch
         state["network_args"] = self._network_args
 
-        if self.model_dir is None:
-            raise ValueError("Experiment directory is not set.")
-
-        self.model_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path.mkdir(parents=True, exist_ok=True)
 
         logger.debug("Saving checkpoint...")
-        filepath = str(self.model_dir / "last.pt")
+        filepath = str(checkpoint_path / "last.pt")
         torch.save(state, filepath)
 
         if is_best:
             logger.debug(
                 f"Found a new best {val_metric}. Saving best checkpoint and weights."
             )
-            shutil.copyfile(filepath, str(self.model_dir / "best.pt"))
+            shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
 
-    def load_weights(self, network_fn: Type[nn.Module]) -> Tuple[Type[nn.Module], Dict]:
+    def load_weights(self, network_fn: Type[nn.Module]) -> None:
         """Load the network weights."""
         logger.debug("Loading network with pretrained weights.")
         filename = glob(self.weights_filename)[0]
@@ -308,13 +410,16 @@ class Model(ABC):
             )
         # Loading state directory.
         state_dict = torch.load(filename, map_location=torch.device(self._device))
-        network_args = state_dict["network_args"]
+        self._network_args = state_dict["network_args"]
         weights = state_dict["model_state"]
 
         # Initializes the network with trained weights.
-        network = network_fn(**self._network_args)
-        network.load_state_dict(weights)
-        return network, network_args
+        self._network = network_fn(**self._network_args)
+        self._network.load_state_dict(weights)
+
+        if "swa_network" in state_dict:
+            self._swa_network = AveragedModel(self._network).to(self.device)
+            self._swa_network.load_state_dict(state_dict["swa_network"])
 
     def save_weights(self, path: Path) -> None:
         """Save the network weights."""
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 0fd7afd..64ba693 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -4,8 +4,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
 import numpy as np
 import torch
 from torch import nn
+from torch.utils.data import Dataset
 from torchvision.transforms import ToTensor
 
+from text_recognizer.datasets import EmnistMapper
 from text_recognizer.models.base import Model
 
 
@@ -15,8 +17,9 @@ class CharacterModel(Model):
     def __init__(
         self,
         network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
         network_args: Optional[Dict] = None,
-        data_loader_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
         metrics: Optional[Dict] = None,
         criterion: Optional[Callable] = None,
         criterion_args: Optional[Dict] = None,
@@ -24,14 +27,16 @@ class CharacterModel(Model):
         optimizer_args: Optional[Dict] = None,
         lr_scheduler: Optional[Callable] = None,
         lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
         device: Optional[str] = None,
     ) -> None:
         """Initializes the CharacterModel."""
 
         super().__init__(
             network_fn,
+            dataset,
             network_args,
-            data_loader_args,
+            dataset_args,
             metrics,
             criterion,
             criterion_args,
@@ -39,8 +44,11 @@ class CharacterModel(Model):
             optimizer_args,
             lr_scheduler,
             lr_scheduler_args,
+            swa_args,
             device,
         )
+        if self._mapper is None:
+            self._mapper = EmnistMapper()
         self.tensor_transform = ToTensor()
         self.softmax = nn.Softmax(dim=0)
 
@@ -67,9 +75,13 @@ class CharacterModel(Model):
 
         # Put the image tensor on the device the model weights are on.
         image = image.to(self.device)
-        logits = self.network(image)
+        logits = (
+            self.swa_network(image)
+            if self.swa_network is not None
+            else self.network(image)
+        )
 
-        prediction = self.softmax(logits.data.squeeze())
+        prediction = self.softmax(logits.squeeze(0))
 
         index = int(torch.argmax(prediction, dim=0))
         confidence_of_prediction = prediction[index]
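
Assuming CharacterModel keeps the predict_on_image signature shown for LineCTCModel
below, a sketch of the inference path (the 28x28 uint8 input shape is an EMNIST-style
assumption, and `model` is the instance from the sketch above):

    import numpy as np

    # A grayscale uint8 crop of a single character; values in [0, 255].
    image = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)
    predicted_character, confidence = model.predict_on_image(image)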
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
new file mode 100644
index 0000000..97308a7
--- /dev/null
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -0,0 +1,105 @@
+"""Defines the LineCTCModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class LineCTCModel(Model):
+    """Model for predicting a sequence of characters from an image of a text line."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        if self._mapper is None:
+            self._mapper = EmnistMapper()
+        self.tensor_transform = ToTensor()
+
+    def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+        """Computes the CTC loss.
+
+        Args:
+            output (Tensor): Model predictions.
+            targets (Tensor): Correct output sequence.
+
+        Returns:
+            Tensor: The CTC loss.
+
+        """
+        input_lengths = torch.full(
+            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+        )
+        target_lengths = torch.full(
+            size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+        )
+        return self.criterion(output, targets, input_lengths, target_lengths)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+        """Predict on a single input."""
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        log_probs = (
+            self.swa_network(image)
+            if self.swa_network is not None
+            else self.network(image)
+        )
+
+        raw_pred, _ = greedy_decoder(
+            predictions=log_probs,
+            character_mapper=self.mapper,
+            blank_label=79,
+            collapse_repeated=True,
+        )
+
+        log_probs, _ = log_probs.max(dim=2)
+
+        predicted_characters = "".join(raw_pred[0])
+        confidence_of_prediction = torch.exp(log_probs.sum()).item()
+
+        return predicted_characters, confidence_of_prediction
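
loss_fn above assumes the network emits log-probabilities shaped (sequence_length,
batch_size, num_classes) and that targets are fixed-width per batch. A self-contained
sketch of the same length bookkeeping with torch.nn.CTCLoss (all sizes are illustrative;
only the blank index 79 is taken from the greedy_decoder call above):

    import torch
    from torch import nn

    T, N, C, S = 97, 4, 80, 32  # time steps, batch size, classes, target width
    log_probs = torch.randn(T, N, C).log_softmax(dim=2)
    targets = torch.randint(1, C - 1, (N, S), dtype=torch.long)  # avoids the blank (79)

    # As in loss_fn: every sample uses the full input and target widths.
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.full((N,), S, dtype=torch.long)

    criterion = nn.CTCLoss(blank=79)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)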
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index ac8d68e..6a26216 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -1,19 +1,89 @@
 """Utility functions for models."""
-
+import Levenshtein as Lev
 import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
 
 
-def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
+def accuracy(outputs: Tensor, labels: Tensor) -> float:
     """Computes the accuracy.
 
     Args:
-        outputs (torch.Tensor): The output from the network.
-        labels (torch.Tensor): Ground truth labels.
+        outputs (Tensor): The output from the network.
+        labels (Tensor): Ground truth labels.
 
     Returns:
         float: The accuracy for the batch.
 
     """
     _, predicted = torch.max(outputs.data, dim=1)
-    acc = (predicted == labels).sum().item() / labels.shape[0]
+    acc = (predicted == labels).sum().float() / labels.shape[0]
+    acc = acc.item()
     return acc
+
+
+def cer(outputs: Tensor, targets: Tensor) -> float:
+    """Computes the character error rate.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        targets (Tensor): Ground truth labels.
+
+    Returns:
+        float: The cer for the batch.
+
+    """
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths
+    )
+
+    lev_dist = 0
+
+    for prediction, target in zip(decoded_predictions, decoded_targets):
+        prediction = "".join(prediction)
+        target = "".join(target)
+        prediction, target = (
+            prediction.replace(" ", ""),
+            target.replace(" ", ""),
+        )
+        lev_dist += Lev.distance(prediction, target)
+    return lev_dist / len(decoded_predictions)
+
+
+def wer(outputs: Tensor, targets: Tensor) -> float:
+    """Computes the Word error rate.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        targets (Tensor): Ground truth labels.
+
+    Returns:
+        float: The wer for the batch.
+
+    """
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths
+    )
+
+    lev_dist = 0
+
+    for prediction, target in zip(decoded_predictions, decoded_targets):
+        prediction = "".join(prediction)
+        target = "".join(target)
+
+        b = set(prediction.split() + target.split())
+        word2char = dict(zip(b, range(len(b))))
+
+        w1 = [chr(word2char[w]) for w in prediction.split()]
+        w2 = [chr(word2char[w]) for w in target.split()]
+
+        lev_dist += Lev.distance("".join(w1), "".join(w2))
+
+    return lev_dist / len(decoded_predictions)
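
The wer implementation relies on a small trick: mapping every distinct word to a unique
single character lets Lev.distance, which operates on characters, compute an edit
distance over words. A standalone illustration (the sentences are made up):

    import Levenshtein as Lev

    prediction = "the quick brown fox"
    target = "the quick red fox"

    # Map each distinct word to a unique one-character symbol.
    vocab = set(prediction.split() + target.split())
    word2char = dict(zip(vocab, range(len(vocab))))

    w1 = "".join(chr(word2char[w]) for w in prediction.split())
    w2 = "".join(chr(word2char[w]) for w in target.split())

    assert Lev.distance(w1, w2) == 1  # one substituted word ("brown" -> "red")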
-- 
cgit v1.2.3-70-g09d2