summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/mlp.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/mlp.py')
-rw-r--r--src/text_recognizer/networks/mlp.py18
1 files changed, 7 insertions, 11 deletions
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index ac2c825..acebdaa 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,6 +5,8 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
+from text_recognizer.networks.misc import activation_function
+
class MLP(nn.Module):
"""Multi layered perceptron network."""
@@ -16,8 +18,7 @@ class MLP(nn.Module):
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
- activation_fn: Optional[Callable] = None,
- activation_fn_args: Optional[Dict] = None,
+ activation_fn: str = "relu",
) -> None:
"""Initialization of the MLP network.
@@ -27,18 +28,13 @@ class MLP(nn.Module):
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
- activation_fn (Optional[Callable]): The activation function in the hidden layers. Defaults to
- None.
- activation_fn_args (Optional[Dict]): The arguments for the activation function. Defaults to None.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
"""
super().__init__()
- if activation_fn is not None:
- activation_fn_args = activation_fn_args or {}
- activation_fn = getattr(nn, activation_fn)(**activation_fn_args)
- else:
- activation_fn = nn.ReLU(inplace=True)
+ activation_fn = activation_function(activation_fn)
if isinstance(hidden_size, int):
hidden_size = [hidden_size] * num_layers
@@ -65,7 +61,7 @@ class MLP(nn.Module):
self.layers = nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward."""
+ """The feedforward pass."""
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)