Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/character_predictor.py                          |   8
-rw-r--r--  src/text_recognizer/datasets/__init__.py                            |   2
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py                      |   9
-rw-r--r--  src/text_recognizer/models/__init__.py                              |   4
-rw-r--r--  src/text_recognizer/models/base.py                                  |  66
-rw-r--r--  src/text_recognizer/models/character_model.py                       |  30
-rw-r--r--  src/text_recognizer/models/metrics.py                               |   2
-rw-r--r--  src/text_recognizer/networks/__init__.py                            |   4
-rw-r--r--  src/text_recognizer/networks/lenet.py                               |  55
-rw-r--r--  src/text_recognizer/networks/mlp.py                                 |  71
-rw-r--r--  src/text_recognizer/tests/test_character_predictor.py              |  19
-rw-r--r--  src/text_recognizer/util.py                                         |   2
-rw-r--r--  src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt | bin 0 -> 14483400 bytes
-rw-r--r--  src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt   | bin 0 -> 1702233 bytes
14 files changed, 155 insertions, 117 deletions
diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py
index 69ef896..a773f36 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/src/text_recognizer/character_predictor.py
@@ -1,8 +1,8 @@
"""CharacterPredictor class."""
-
-from typing import Tuple, Union
+from typing import Dict, Tuple, Type, Union
import numpy as np
+from torch import nn
from text_recognizer.models import CharacterModel
from text_recognizer.util import read_image
@@ -11,9 +11,9 @@ from text_recognizer.util import read_image
class CharacterPredictor:
"""Recognizes the character in handwritten character images."""
- def __init__(self) -> None:
+ def __init__(self, network_fn: Type[nn.Module], network_args: Dict) -> None:
"""Intializes the CharacterModel and load the pretrained weights."""
- self.model = CharacterModel()
+ self.model = CharacterModel(network_fn=network_fn, network_args=network_args)
self.model.load_weights()
self.model.eval()
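
CharacterPredictor now takes the network class and its arguments explicitly instead of hard-coding them inside CharacterModel. A minimal sketch of the new call site, mirroring the values used in the updated unit test further down (the image path is hypothetical):

from text_recognizer.character_predictor import CharacterPredictor
from text_recognizer.networks import MLP

# Values copied from the updated test; any MLP-compatible settings work.
network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2}
predictor = CharacterPredictor(network_fn=MLP, network_args=network_args)

pred, conf = predictor.predict("support/emnist/a.png")  # hypothetical path
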
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index aec5bf9..795be90 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,4 @@
"""Dataset modules."""
from .emnist_dataset import EmnistDataLoader
+
+__all__ = ["EmnistDataLoader"]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index a17d7a9..b92b57d 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -2,7 +2,7 @@
import json
from pathlib import Path
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Type
from loguru import logger
import numpy as np
@@ -102,21 +102,22 @@ class EmnistDataLoader:
self.shuffle = shuffle
self.num_workers = num_workers
self.cuda = cuda
+ self.seed = seed
self._data_loaders = self._fetch_emnist_data_loaders()
@property
def __name__(self) -> str:
"""Returns the name of the dataset."""
- return "EMNIST"
+ return "Emnist"
- def __call__(self, split: str) -> Optional[DataLoader]:
+ def __call__(self, split: str) -> DataLoader:
"""Returns the `split` DataLoader.
Args:
split (str): The dataset split, i.e. train or val.
Returns:
- type: A PyTorch DataLoader.
+ DataLoader: A PyTorch DataLoader.
Raises:
ValueError: If the split does not exist.
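
With this change __call__ is annotated to always return a DataLoader and to raise ValueError for an unknown split, rather than possibly returning None. A hedged usage sketch; the constructor arguments below are assumptions, since only shuffle, num_workers, cuda and seed are visible in this hunk:

from text_recognizer.datasets import EmnistDataLoader

# Constructor arguments are illustrative; see the full class for the real signature.
data_loaders = EmnistDataLoader(shuffle=True, num_workers=0, cuda=False, seed=0)

train_loader = data_loaders("train")  # a torch.utils.data.DataLoader
val_loader = data_loaders("val")
# data_loaders("test") raises ValueError if that split was not created.
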
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index d265dcf..ff10a07 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,2 +1,6 @@
"""Model modules."""
+from .base import Model
from .character_model import CharacterModel
+from .metrics import accuracy
+
+__all__ = ["Model", "CharacterModel", "accuracy"]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 0cc531a..b78eacb 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -1,9 +1,11 @@
"""Abstract Model class for PyTorch neural networks."""
from abc import ABC, abstractmethod
+from glob import glob
from pathlib import Path
+import re
import shutil
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
from loguru import logger
import torch
@@ -19,7 +21,7 @@ class Model(ABC):
def __init__(
self,
- network_fn: Callable,
+ network_fn: Type[nn.Module],
network_args: Dict,
data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
@@ -35,7 +37,7 @@ class Model(ABC):
"""Base class, to be inherited by model for specific type of data.
Args:
- network_fn (Callable): The PyTorch network.
+ network_fn (Type[nn.Module]): The PyTorch network.
network_args (Dict): Arguments for the network.
data_loader (Optional[Callable]): A function that fetches train and val DataLoader.
data_loader_args (Optional[Dict]): Arguments for the DataLoader.
@@ -57,27 +59,29 @@ class Model(ABC):
self._data_loaders = data_loader(**data_loader_args)
dataset_name = self._data_loaders.__name__
else:
- dataset_name = ""
+ dataset_name = "*"
self._data_loaders = None
- self.name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
+ self._name = f"{self.__class__.__name__}_{dataset_name}_{network_fn.__name__}"
# Extract the input shape for the torchsummary.
- self._input_shape = network_args.pop("input_shape")
+ if isinstance(network_args["input_size"], int):
+ self._input_shape = (1,) + tuple([network_args["input_size"]])
+ else:
+ self._input_shape = (1,) + tuple(network_args["input_size"])
if metrics is not None:
self._metrics = metrics
# Set the device.
- if self.device is None:
- self._device = torch.device(
- "cuda:0" if torch.cuda.is_available() else "cpu"
- )
+ if device is None:
+ self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self._device = device
# Load network.
- self._network = network_fn(**network_args)
+ self.network_args = network_args
+ self._network = network_fn(**self.network_args)
# To device.
self._network.to(self._device)
@@ -95,13 +99,29 @@ class Model(ABC):
# Set learning rate scheduler.
self._lr_scheduler = None
if lr_scheduler is not None:
+ # OneCycleLR needs the number of steps in an epoch as an input argument.
+ if "OneCycleLR" in str(lr_scheduler):
+ lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders("train"))
self._lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args)
+ # Class mapping.
+ self._mapping = None
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the model."""
+ return self._name
+
@property
def input_shape(self) -> Tuple[int, ...]:
"""The input shape."""
return self._input_shape
+ @property
+ def mapping(self) -> Dict:
+ """Returns the class mapping."""
+ return self._mapping
+
def eval(self) -> None:
"""Sets the network to evaluation mode."""
self._network.eval()
@@ -149,13 +169,14 @@ class Model(ABC):
def weights_filename(self) -> str:
"""Filepath to the network weights."""
WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
- return str(WEIGHT_DIRNAME / f"{self.name}_weights.pt")
+ return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
def summary(self) -> None:
"""Prints a summary of the network architecture."""
- summary(self._network, self._input_shape, device=self.device)
+ device = re.sub("[^A-Za-z]+", "", self.device)
+ summary(self._network, self._input_shape, device=device)
- def _get_state(self) -> Dict:
+ def _get_state_dict(self) -> Dict:
"""Get the state dict of the model."""
state = {"model_state": self._network.state_dict()}
if self._optimizer is not None:
@@ -172,6 +193,7 @@ class Model(ABC):
epoch (int): The last epoch when the checkpoint was created.
"""
+ logger.debug("Loading checkpoint...")
if not path.exists():
logger.debug("File does not exist {str(path)}")
@@ -200,6 +222,7 @@ class Model(ABC):
state = self._get_state_dict()
state["is_best"] = is_best
state["epoch"] = epoch
+ state["network_args"] = self.network_args
path.mkdir(parents=True, exist_ok=True)
@@ -216,15 +239,18 @@ class Model(ABC):
def load_weights(self) -> None:
"""Load the network weights."""
logger.debug("Loading network weights.")
- weights = torch.load(self.weights_filename)["model_state"]
+ filename = glob(self.weights_filename)[0]
+ weights = torch.load(filename, map_location=torch.device(self._device))[
+ "model_state"
+ ]
self._network.load_state_dict(weights)
- def save_weights(self) -> None:
+ def save_weights(self, path: Path) -> None:
"""Save the network weights."""
- logger.debug("Saving network weights.")
- torch.save({"model_state": self._network.state_dict()}, self.weights_filename)
+ logger.debug("Saving the best network weights.")
+ shutil.copyfile(str(path / "best.pt"), self.weights_filename)
@abstractmethod
- def mapping(self) -> Dict:
- """Mapping from network output to class."""
+ def load_mapping(self) -> None:
+ """Loads class mapping from network output to character."""
...
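
Two related changes are worth reading together. First, the input shape for torchsummary is now derived from network_args["input_size"] instead of popping a separate input_shape key, so the same dict can be handed straight to the network constructor. Second, when no data loader is given, dataset_name becomes "*", which turns the weights filename into a glob pattern that load_weights resolves, so a predictor can load pretrained weights without constructing a data loader. A small sketch of the shape rule, using values that appear in the updated test:

def derive_input_shape(input_size):
    """Mirrors the new logic in Model.__init__ (illustrative only)."""
    if isinstance(input_size, int):
        return (1,) + (input_size,)
    return (1,) + tuple(input_size)

derive_input_shape(784)       # (1, 784)    -- the MLP case
derive_input_shape([28, 28])  # (1, 28, 28) -- the list form from the commented-out test args
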
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index fd69bf2..527fc7d 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -1,5 +1,5 @@
"""Defines the CharacterModel class."""
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Type
import numpy as np
import torch
@@ -8,7 +8,6 @@ from torchvision.transforms import ToTensor
from text_recognizer.datasets.emnist_dataset import load_emnist_mapping
from text_recognizer.models.base import Model
-from text_recognizer.networks.mlp import mlp
class CharacterModel(Model):
@@ -16,8 +15,9 @@ class CharacterModel(Model):
def __init__(
self,
- network_fn: Callable,
+ network_fn: Type[nn.Module],
network_args: Dict,
+ data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -33,6 +33,7 @@ class CharacterModel(Model):
super().__init__(
network_fn,
network_args,
+ data_loader,
data_loader_args,
metrics,
criterion,
@@ -43,13 +44,13 @@ class CharacterModel(Model):
lr_scheduler_args,
device,
)
- self.emnist_mapping = self.mapping()
- self.eval()
+ self.load_mapping()
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
- def mapping(self) -> Dict[int, str]:
+ def load_mapping(self) -> None:
"""Mapping between integers and classes."""
- mapping = load_emnist_mapping()
- return mapping
+ self._mapping = load_emnist_mapping()
def predict_on_image(self, image: np.ndarray) -> Tuple[str, float]:
"""Character prediction on an image.
@@ -61,15 +62,20 @@ class CharacterModel(Model):
Tuple[str, float]: The predicted character and the confidence in the prediction.
"""
+
if image.dtype == np.uint8:
image = (image / 255).astype(np.float32)
# Convert to PyTorch tensor.
- image = ToTensor(image)
+ image = self.tensor_transform(image)
+
+ with torch.no_grad():
+ logits = self.network(image)
+
+ prediction = self.softmax(logits.data.squeeze())
- prediction = self.network(image)
- index = torch.argmax(prediction, dim=1)
+ index = int(torch.argmax(prediction, dim=0))
confidence_of_prediction = prediction[index]
- predicted_character = self.emnist_mapping[index]
+ predicted_character = self._mapping[index]
return predicted_character, confidence_of_prediction
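
predict_on_image now runs the forward pass under torch.no_grad(), applies an explicit softmax over the logits, and indexes the EMNIST mapping with a plain int. A hedged end-to-end sketch (the all-zeros image is a stand-in for a real character crop):

import numpy as np
from text_recognizer.models import CharacterModel
from text_recognizer.networks import MLP

model = CharacterModel(
    network_fn=MLP,
    network_args={"input_size": 784, "output_size": 62, "dropout_rate": 0.2},
)
model.load_weights()  # resolves CharacterModel_*_MLP_weights.pt via glob
model.eval()

image = np.zeros((28, 28), dtype=np.uint8)  # stand-in for a 28x28 grayscale crop
character, confidence = model.predict_on_image(image)
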
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index e2a30a9..ac8d68e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -3,7 +3,7 @@
import torch
-def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float:
+def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
"""Computes the accuracy.
Args:
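
Only the torch.Tensro typo in the type hint changes here. The function body is not part of this hunk; for orientation, an accuracy metric of this shape is typically implemented along these lines (assumed, not taken from the source):

import torch

def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose argmax prediction matches the label (assumed body)."""
    predictions = torch.argmax(outputs, dim=1)
    return (predictions == labels).float().mean().item()
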
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 4ea5bb3..e6b6946 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1 +1,5 @@
"""Network modules."""
+from .lenet import LeNet
+from .mlp import MLP
+
+__all__ = ["MLP", "LeNet"]
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 71d247f..2839a0c 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,5 +1,5 @@
"""Defines the LeNet network."""
-from typing import Callable, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
import torch
from torch import nn
@@ -18,28 +18,37 @@ class LeNet(nn.Module):
def __init__(
self,
- channels: Tuple[int, ...],
- kernel_sizes: Tuple[int, ...],
- hidden_size: Tuple[int, ...],
- dropout_rate: float,
- output_size: int,
+ input_size: Tuple[int, ...] = (1, 28, 28),
+ channels: Tuple[int, ...] = (1, 32, 64),
+ kernel_sizes: Tuple[int, ...] = (3, 3, 2),
+ hidden_size: Tuple[int, ...] = (9216, 128),
+ dropout_rate: float = 0.2,
+ output_size: int = 10,
activation_fn: Optional[Callable] = None,
+ activation_fn_args: Optional[Dict] = None,
) -> None:
"""The LeNet network.
Args:
- channels (Tuple[int, ...]): Channels in the convolutional layers.
- kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers.
+ input_size (Tuple[int, ...]): The input shape of the network. Defaults to (1, 28, 28).
+ channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
- dropout_rate (float): The dropout rate.
- output_size (int): Number of classes.
+ Defaults to (9216, 128).
+ dropout_rate (float): The dropout rate. Defaults to 0.2.
+ output_size (int): Number of classes. Defaults to 10.
activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
nn.ReLU(inplace).
+ activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
"""
super().__init__()
- if activation_fn is None:
+ self._input_size = input_size
+
+ if activation_fn is not None:
+ activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ else:
activation_fn = nn.ReLU(inplace=True)
self.layers = [
@@ -68,26 +77,6 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward."""
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
return self.layers(x)
-
-
-# def test():
-# x = torch.randn([1, 1, 28, 28])
-# channels = [1, 32, 64]
-# kernel_sizes = [3, 3, 2]
-# hidden_size = [9216, 128]
-# output_size = 10
-# dropout_rate = 0.2
-# activation_fn = nn.ReLU()
-# net = LeNet(
-# channels=channels,
-# kernel_sizes=kernel_sizes,
-# dropout_rate=dropout_rate,
-# hidden_size=hidden_size,
-# output_size=output_size,
-# activation_fn=activation_fn,
-# )
-# from torchsummary import summary
-#
-# summary(net, (1, 28, 28), device="cpu")
-# out = net(x)
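
LeNet now has defaults for every argument, keeps its input size, optionally builds the activation from a name via getattr(nn, ...), and unsqueezes unbatched 3-D inputs in forward. Because of the getattr lookup, activation_fn is effectively a string such as "SELU" when supplied. A short sketch with the defaults:

import torch
from text_recognizer.networks import LeNet

net = LeNet()  # defaults: input_size=(1, 28, 28), 10 output classes

x = torch.randn(1, 28, 28)  # single unbatched image
out = net(x)                # forward unsqueezes to (1, 1, 28, 28) before the conv stack
print(out.shape)            # expected: torch.Size([1, 10]) with the default head
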
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index 2a41790..d704d99 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -1,5 +1,5 @@
"""Defines the MLP network."""
-from typing import Callable, Optional
+from typing import Callable, Dict, List, Optional, Union
import torch
from torch import nn
@@ -10,45 +10,54 @@ class MLP(nn.Module):
def __init__(
self,
- input_size: int,
- output_size: int,
- hidden_size: int,
- num_layers: int,
- dropout_rate: float,
+ input_size: int = 784,
+ output_size: int = 10,
+ hidden_size: Union[int, List] = 128,
+ num_layers: int = 3,
+ dropout_rate: float = 0.2,
activation_fn: Optional[Callable] = None,
+ activation_fn_args: Optional[Dict] = None,
) -> None:
"""Initialization of the MLP network.
Args:
- input_size (int): The input shape of the network.
- output_size (int): Number of classes in the dataset.
- hidden_size (int): The number of `neurons` in each hidden layer.
- num_layers (int): The number of hidden layers.
- dropout_rate (float): The dropout rate at each layer.
- activation_fn (Optional[Callable]): The activation function in the hidden layers, (default:
- nn.ReLU()).
+ input_size (int): The input shape of the network. Defaults to 784.
+ output_size (int): Number of classes in the dataset. Defaults to 10.
+ hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+ num_layers (int): The number of hidden layers. Defaults to 3.
+ dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
+ activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
+ None.
+ activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
"""
super().__init__()
- if activation_fn is None:
+ if activation_fn is not None:
+ activation_fn = getattr(nn, activation_fn)(activation_fn_args)
+ else:
activation_fn = nn.ReLU(inplace=True)
+ if isinstance(hidden_size, int):
+ hidden_size = [hidden_size] * num_layers
+
self.layers = [
- nn.Linear(in_features=input_size, out_features=hidden_size),
+ nn.Linear(in_features=input_size, out_features=hidden_size[0]),
activation_fn,
]
- for _ in range(num_layers):
+ for i in range(num_layers - 1):
self.layers += [
- nn.Linear(in_features=hidden_size, out_features=hidden_size),
+ nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
activation_fn,
]
if dropout_rate:
self.layers.append(nn.Dropout(p=dropout_rate))
- self.layers.append(nn.Linear(in_features=hidden_size, out_features=output_size))
+ self.layers.append(
+ nn.Linear(in_features=hidden_size[-1], out_features=output_size)
+ )
self.layers = nn.Sequential(*self.layers)
@@ -57,25 +66,7 @@ class MLP(nn.Module):
x = torch.flatten(x, start_dim=1)
return self.layers(x)
-
-# def test():
-# x = torch.randn([1, 28, 28])
-# input_size = torch.flatten(x).shape[0]
-# output_size = 10
-# hidden_size = 128
-# num_layers = 5
-# dropout_rate = 0.25
-# activation_fn = nn.GELU()
-# net = MLP(
-# input_size=input_size,
-# output_size=output_size,
-# hidden_size=hidden_size,
-# num_layers=num_layers,
-# dropout_rate=dropout_rate,
-# activation_fn=activation_fn,
-# )
-# from torchsummary import summary
-#
-# summary(net, (1, 28, 28), device="cpu")
-#
-# out = net(x)
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the network."""
+ return "mlp"
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py
index 7c094ef..c603a3a 100644
--- a/src/text_recognizer/tests/test_character_predictor.py
+++ b/src/text_recognizer/tests/test_character_predictor.py
@@ -1,9 +1,14 @@
"""Test for CharacterPredictor class."""
+import importlib
import os
from pathlib import Path
import unittest
+import click
+from loguru import logger
+
from text_recognizer.character_predictor import CharacterPredictor
+from text_recognizer.networks import MLP
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist"
@@ -13,13 +18,23 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
class TestCharacterPredictor(unittest.TestCase):
"""Tests for the CharacterPredictor class."""
+ # @click.command()
+ # @click.option(
+ # "--network", type=str, help="Network to load, e.g. MLP or LeNet.", default="MLP"
+ # )
def test_filename(self) -> None:
"""Test that CharacterPredictor correctly predicts on a single image, for serveral test images."""
- predictor = CharacterPredictor()
+ network_module = importlib.import_module("text_recognizer.networks")
+ network_fn_ = getattr(network_module, "MLP")
+ # network_args = {"input_size": [28, 28], "output_size": 62, "dropout_rate": 0}
+ network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2}
+ predictor = CharacterPredictor(
+ network_fn=network_fn_, network_args=network_args
+ )
for filename in SUPPORT_DIRNAME.glob("*.png"):
pred, conf = predictor.predict(str(filename))
- print(
+ logger.info(
f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}"
)
self.assertEqual(pred, filename.stem)
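
The test now resolves the network class by name through importlib, which is the pattern the commented-out click option points toward. A minimal sketch of that dynamic lookup:

import importlib

# "MLP" could come from a CLI flag like the commented-out --network option.
network_module = importlib.import_module("text_recognizer.networks")
network_fn = getattr(network_module, "MLP")  # or "LeNet"

network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2}
network = network_fn(**network_args)
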
diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py
index 52fa1e4..6c07c60 100644
--- a/src/text_recognizer/util.py
+++ b/src/text_recognizer/util.py
@@ -25,7 +25,7 @@ def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarr
) from None
imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
- local_file = os.path.exsits(image_uri)
+ local_file = os.path.exists(image_uri)
try:
image = None
if local_file:
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
new file mode 100644
index 0000000..43a3891
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
new file mode 100644
index 0000000..0dde787
--- /dev/null
+++ b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt
Binary files differ