diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
commit | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch) | |
tree | 1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/networks/sparse_mlp.py | |
parent | e181195a699d7fa237f256d90ab4dedffc03d405 (diff) |
new updates
Diffstat (limited to 'src/text_recognizer/networks/sparse_mlp.py')
-rw-r--r-- | src/text_recognizer/networks/sparse_mlp.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py new file mode 100644 index 0000000..53cf166 --- /dev/null +++ b/src/text_recognizer/networks/sparse_mlp.py @@ -0,0 +1,78 @@ +"""Defines the Sparse MLP network.""" +from typing import Callable, Dict, List, Optional, Union +import warnings + +from einops.layers.torch import Rearrange +from pytorch_block_sparse import BlockSparseLinear +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +class SparseMLP(nn.Module): + """Sparse multi layered perceptron network.""" + + def __init__( + self, + input_size: int = 784, + num_classes: int = 10, + hidden_size: Union[int, List] = 128, + num_layers: int = 3, + density: float = 0.1, + activation_fn: str = "relu", + ) -> None: + """Initialization of the MLP network. + + Args: + input_size (int): The input shape of the network. Defaults to 784. + num_classes (int): Number of classes in the dataset. Defaults to 10. + hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. + num_layers (int): The number of hidden layers. Defaults to 3. + density (float): The density of activation at each layer. Default to 0.1. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. + + """ + super().__init__() + + activation_fn = activation_function(activation_fn) + + if isinstance(hidden_size, int): + hidden_size = [hidden_size] * num_layers + + self.layers = [ + Rearrange("b c h w -> b (c h w)"), + nn.Linear(in_features=input_size, out_features=hidden_size[0]), + activation_fn, + ] + + for i in range(num_layers - 1): + self.layers += [ + BlockSparseLinear( + in_features=hidden_size[i], + out_features=hidden_size[i + 1], + density=density, + ), + activation_fn, + ] + + self.layers.append( + nn.Linear(in_features=hidden_size[-1], out_features=num_classes) + ) + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.layers(x) + + @property + def __name__(self) -> str: + """Returns the name of the network.""" + return "mlp" |