summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/mlp.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
commite1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/networks/mlp.py
parentfe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/networks/mlp.py')
-rw-r--r--src/text_recognizer/networks/mlp.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index acebdaa..d66af28 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -14,7 +14,7 @@ class MLP(nn.Module):
def __init__(
self,
input_size: int = 784,
- output_size: int = 10,
+ num_classes: int = 10,
hidden_size: Union[int, List] = 128,
num_layers: int = 3,
dropout_rate: float = 0.2,
@@ -24,7 +24,7 @@ class MLP(nn.Module):
Args:
input_size (int): The input shape of the network. Defaults to 784.
- output_size (int): Number of classes in the dataset. Defaults to 10.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
num_layers (int): The number of hidden layers. Defaults to 3.
dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
@@ -55,7 +55,7 @@ class MLP(nn.Module):
self.layers.append(nn.Dropout(p=dropout_rate))
self.layers.append(
- nn.Linear(in_features=hidden_size[-1], out_features=output_size)
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
)
self.layers = nn.Sequential(*self.layers)