summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/lenet.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
commite1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/networks/lenet.py
parentfe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/networks/lenet.py')
-rw-r--r--src/text_recognizer/networks/lenet.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 91d3f2c..53c575e 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -1,4 +1,4 @@
-"""Defines the LeNet network."""
+"""Implementation of the LeNet network."""
from typing import Callable, Dict, Optional, Tuple
from einops.layers.torch import Rearrange
@@ -9,7 +9,7 @@ from text_recognizer.networks.misc import activation_function
class LeNet(nn.Module):
- """LeNet network."""
+ """LeNet network for character prediction."""
def __init__(
self,
@@ -17,10 +17,10 @@ class LeNet(nn.Module):
kernel_sizes: Tuple[int, ...] = (3, 3, 2),
hidden_size: Tuple[int, ...] = (9216, 128),
dropout_rate: float = 0.2,
- output_size: int = 10,
+ num_classes: int = 10,
activation_fn: Optional[str] = "relu",
) -> None:
- """The LeNet network.
+ """Initialization of the LeNet network.
Args:
channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
@@ -28,7 +28,7 @@ class LeNet(nn.Module):
hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
Defaults to (9216, 128).
dropout_rate (float): The dropout rate. Defaults to 0.2.
- output_size (int): Number of classes. Defaults to 10.
+ num_classes (int): Number of classes. Defaults to 10.
activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
"""
@@ -55,7 +55,7 @@ class LeNet(nn.Module):
nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
activation_fn,
nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=hidden_size[1], out_features=output_size),
+ nn.Linear(in_features=hidden_size[1], out_features=num_classes),
]
self.layers = nn.Sequential(*self.layers)