diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/text_recognizer | |
parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 43 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 9 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 45 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 8 | ||||
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 3 | ||||
-rw-r--r-- | src/text_recognizer/networks/lenet.py | 17 | ||||
-rw-r--r-- | src/text_recognizer/networks/misc.py | 20 | ||||
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 18 | ||||
-rw-r--r-- | src/text_recognizer/networks/residual_network.py | 314 | ||||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt | bin | 14485310 -> 14485362 bytes | |||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt | bin | 1704174 -> 11625484 bytes | |||
-rw-r--r-- | src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt | bin | 0 -> 28654593 bytes |
12 files changed, 400 insertions, 77 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 96f84e5..49ebad3 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -8,6 +8,7 @@ from loguru import logger import numpy as np from PIL import Image import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.datasets import EMNIST from torchvision.transforms import Compose, Normalize, ToTensor @@ -183,12 +184,8 @@ class EmnistDataset(Dataset): self.input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes - # Placeholders - self.data = None - self.targets = None - # Load dataset. - self.load_emnist_dataset() + self.data, self.targets = self.load_emnist_dataset() @property def mapper(self) -> EmnistMapper: @@ -199,9 +196,7 @@ class EmnistDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches samples from the dataset. Args: @@ -239,11 +234,13 @@ class EmnistDataset(Dataset): f"Mapping: {self.mapper.mapping}\n" ) - def _sample_to_balance(self) -> None: + def _sample_to_balance( + self, data: Tensor, targets: Tensor + ) -> Tuple[np.ndarray, np.ndarray]: """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(self.seed) - x = self.data - y = self.targets + x = data + y = targets num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_indices = [] for label in np.unique(y.flatten()): @@ -253,20 +250,22 @@ class EmnistDataset(Dataset): indices = np.concatenate(all_sampled_indices) x_sampled = x[indices] y_sampled = y[indices] - self.data = x_sampled - self.targets = y_sampled + data = x_sampled + targets = y_sampled + return data, targets - def _subsample(self) -> None: + def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: """Subsamples the dataset to the specified fraction.""" - x = self.data - y = self.targets + x = data + y = targets num_samples = int(x.shape[0] * self.subsample_fraction) x_sampled = x[:num_samples] y_sampled = y[:num_samples] self.data = x_sampled self.targets = y_sampled + return data, targets - def load_emnist_dataset(self) -> None: + def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]: """Fetch the EMNIST dataset.""" dataset = EMNIST( root=DATA_DIRNAME, @@ -277,11 +276,13 @@ class EmnistDataset(Dataset): target_transform=None, ) - self.data = dataset.data - self.targets = dataset.targets + data = dataset.data + targets = dataset.targets if self.sample_to_balance: - self._sample_to_balance() + data, targets = self._sample_to_balance(data, targets) if self.subsample_fraction is not None: - self._subsample() + data, targets = self._subsample(data, targets) + + return data, targets diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index d64a991..b0617f5 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,6 +8,7 @@ import h5py from loguru import logger import numpy as np import torch +from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor @@ -87,16 +88,14 @@ class EmnistLinesDataset(Dataset): """Returns the length of the dataset.""" return len(self.data) - def __getitem__( - self, index: Union[int, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. Args: - index (Union[int, torch.Tensor]): Either a list or int of indices/index. + index (Union[int, Tensor]): Either a list or int of indices/index. Returns: - Tuple[torch.Tensor, torch.Tensor]: Data target pair. + Tuple[Tensor, Tensor]: Data target pair. """ if torch.is_tensor(index): diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 6d40b49..74fd223 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -53,8 +53,8 @@ class Model(ABC): """ - # Fetch data loaders and dataset info. - dataset_name, self._data_loaders, self._mapper = self._load_data_loader( + # Configure data loaders and dataset info. + dataset_name, self._data_loaders, self._mapper = self._configure_data_loader( data_loader_args ) self._input_shape = self._mapper.input_shape @@ -70,16 +70,19 @@ class Model(ABC): else: self._device = device - # Load network. - self._network, self._network_args = self._load_network(network_fn, network_args) + # Configure network. + self._network, self._network_args = self._configure_network( + network_fn, network_args + ) # To device. self._network.to(self._device) - # Set training objects. - self._criterion = self._load_criterion(criterion, criterion_args) - self._optimizer = self._load_optimizer(optimizer, optimizer_args) - self._lr_scheduler = self._load_lr_scheduler(lr_scheduler, lr_scheduler_args) + # Configure training objects. + self._criterion = self._configure_criterion(criterion, criterion_args) + self._optimizer, self._lr_scheduler = self._configure_optimizers( + optimizer, optimizer_args, lr_scheduler, lr_scheduler_args + ) # Experiment directory. self.model_dir = None @@ -87,7 +90,7 @@ class Model(ABC): # Flag for stopping training. self.stop_training = False - def _load_data_loader( + def _configure_data_loader( self, data_loader_args: Optional[Dict] ) -> Tuple[str, Dict, EmnistMapper]: """Loads data loader, dataset name, and dataset mapper.""" @@ -102,7 +105,7 @@ class Model(ABC): data_loaders = None return dataset_name, data_loaders, mapper - def _load_network( + def _configure_network( self, network_fn: Type[nn.Module], network_args: Optional[Dict] ) -> Tuple[Type[nn.Module], Dict]: """Loads the network.""" @@ -113,7 +116,7 @@ class Model(ABC): network = network_fn(**network_args) return network, network_args - def _load_criterion( + def _configure_criterion( self, criterion: Optional[Callable], criterion_args: Optional[Dict] ) -> Optional[Callable]: """Loads the criterion.""" @@ -123,27 +126,27 @@ class Model(ABC): _criterion = None return _criterion - def _load_optimizer( - self, optimizer: Optional[Callable], optimizer_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads the optimizer.""" + def _configure_optimizers( + self, + optimizer: Optional[Callable], + optimizer_args: Optional[Dict], + lr_scheduler: Optional[Callable], + lr_scheduler_args: Optional[Dict], + ) -> Tuple[Optional[Callable], Optional[Callable]]: + """Loads the optimizers.""" if optimizer is not None: _optimizer = optimizer(self._network.parameters(), **optimizer_args) else: _optimizer = None - return _optimizer - def _load_lr_scheduler( - self, lr_scheduler: Optional[Callable], lr_scheduler_args: Optional[Dict] - ) -> Optional[Callable]: - """Loads learning rate scheduler.""" if self._optimizer and lr_scheduler is not None: if "OneCycleLR" in str(lr_scheduler): lr_scheduler_args["steps_per_epoch"] = len(self._data_loaders["train"]) _lr_scheduler = lr_scheduler(self._optimizer, **lr_scheduler_args) else: _lr_scheduler = None - return _lr_scheduler + + return _optimizer, _lr_scheduler @property def __name__(self) -> str: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 0a0ab2d..0fd7afd 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -44,6 +44,7 @@ class CharacterModel(Model): self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) + @torch.no_grad() def predict_on_image( self, image: Union[np.ndarray, torch.Tensor] ) -> Tuple[str, float]: @@ -64,10 +65,9 @@ class CharacterModel(Model): # If the image is an unscaled tensor. image = image.type("torch.FloatTensor") / 255 - with torch.no_grad(): - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - logits = self.network(image) + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + logits = self.network(image) prediction = self.softmax(logits.data.squeeze()) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index e6b6946..a83ca35 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,5 +1,6 @@ """Network modules.""" from .lenet import LeNet from .mlp import MLP +from .residual_network import ResidualNetwork -__all__ = ["MLP", "LeNet"] +__all__ = ["MLP", "LeNet", "ResidualNetwork"] diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index cbc58fc..91d3f2c 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class LeNet(nn.Module): """LeNet network.""" @@ -16,8 +18,7 @@ class LeNet(nn.Module): hidden_size: Tuple[int, ...] = (9216, 128), dropout_rate: float = 0.2, output_size: int = 10, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: Optional[str] = "relu", ) -> None: """The LeNet network. @@ -28,18 +29,12 @@ class LeNet(nn.Module): Defaults to (9216, 128). dropout_rate (float): The dropout rate. Defaults to 0.2. output_size (int): Number of classes. Defaults to 10. - activation_fn (Optional[Callable]): The non-linear activation function. Defaults to - nn.ReLU(inplace). - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) self.layers = [ nn.Conv2d( @@ -66,7 +61,7 @@ class LeNet(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 2fbab8f..6f61b5d 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -1,9 +1,9 @@ """Miscellaneous neural network functionality.""" -from typing import Tuple +from typing import Tuple, Type from einops import rearrange import torch -from torch.nn import Unfold +from torch import nn def sliding_window( @@ -20,10 +20,24 @@ def sliding_window( torch.Tensor: A tensor with the shape (batch, patches, height, width). """ - unfold = Unfold(kernel_size=patch_size, stride=stride) + unfold = nn.Unfold(kernel_size=patch_size, stride=stride) # Preform the slidning window, unsqueeze as the channel dimesion is lost. patches = unfold(images).unsqueeze(1) patches = rearrange( patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] ) return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["gelu", nn.GELU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index ac2c825..acebdaa 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange import torch from torch import nn +from text_recognizer.networks.misc import activation_function + class MLP(nn.Module): """Multi layered perceptron network.""" @@ -16,8 +18,7 @@ class MLP(nn.Module): hidden_size: Union[int, List] = 128, num_layers: int = 3, dropout_rate: float = 0.2, - activation_fn: Optional[Callable] = None, - activation_fn_args: Optional[Dict] = None, + activation_fn: str = "relu", ) -> None: """Initialization of the MLP network. @@ -27,18 +28,13 @@ class MLP(nn.Module): hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. num_layers (int): The number of hidden layers. Defaults to 3. dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to - None. - activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. """ super().__init__() - if activation_fn is not None: - activation_fn_args = activation_fn_args or {} - activation_fn = getattr(nn, activation_fn)(**activation_fn_args) - else: - activation_fn = nn.ReLU(inplace=True) + activation_fn = activation_function(activation_fn) if isinstance(hidden_size, int): hidden_size = [hidden_size] * num_layers @@ -65,7 +61,7 @@ class MLP(nn.Module): self.layers = nn.Sequential(*self.layers) def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward.""" + """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 23394b0..47e351a 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -1 +1,315 @@ """Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.misc import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. + + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class Encoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: List[int] = (32, 64), + depths: List[int] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + *args, + **kwargs + ) -> None: + super().__init__() + + self.block_sizes = block_sizes + self.depths = depths + self.activation = activation + + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=3, + stride=2, + padding=3, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + ) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + return self.blocks(x) + + +class Decoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = Encoder(in_channels, *args, **kwargs) + self.decoder = Decoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt Binary files differindex 81ef9be..676eb44 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt Binary files differindex 49bd166..86cf103 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt Binary files differnew file mode 100644 index 0000000..008beb2 --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt |