diff options
Diffstat (limited to 'src/text_recognizer/networks/mlp.py')
-rw-r--r-- | src/text_recognizer/networks/mlp.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index d704d99..ac2c825 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -1,6 +1,7 @@ """Defines the MLP network.""" from typing import Callable, Dict, List, Optional, Union +from einops.layers.torch import Rearrange import torch from torch import nn @@ -34,7 +35,8 @@ class MLP(nn.Module): super().__init__() if activation_fn is not None: - activation_fn = getattr(nn, activation_fn)(activation_fn_args) + activation_fn_args = activation_fn_args or {} + activation_fn = getattr(nn, activation_fn)(**activation_fn_args) else: activation_fn = nn.ReLU(inplace=True) @@ -42,6 +44,7 @@ class MLP(nn.Module): hidden_size = [hidden_size] * num_layers self.layers = [ + Rearrange("b c h w -> b (c h w)"), nn.Linear(in_features=input_size, out_features=hidden_size[0]), activation_fn, ] @@ -63,7 +66,9 @@ class MLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward.""" - x = torch.flatten(x, start_dim=1) + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) return self.layers(x) @property |