path: root/src/text_recognizer/networks/lenet.py
author     aktersnurra <gustaf.rydholm@gmail.com>    2020-06-23 22:39:54 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>    2020-06-23 22:39:54 +0200
commit     7c4de6d88664d2ea1b084f316a11896dde3e1150 (patch)
tree       cbde7e64c8064e9b523dfb65cd6c487d061ec805 /src/text_recognizer/networks/lenet.py
parent     a7a9ce59fc37317dd74c3a52caf7c4e68e570ee8 (diff)
Diffstat (limited to 'src/text_recognizer/networks/lenet.py')
-rw-r--r--    src/text_recognizer/networks/lenet.py    93
1 file changed, 93 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
new file mode 100644
index 0000000..71d247f
--- /dev/null
+++ b/src/text_recognizer/networks/lenet.py
@@ -0,0 +1,93 @@
+"""Defines the LeNet network."""
+from typing import Callable, Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class Flatten(nn.Module):
+ """Flattens a tensor."""
+
+ def forward(self, x: int) -> torch.Tensor:
+ """Flattens a tensor for input to a nn.Linear layer."""
+ return torch.flatten(x, start_dim=1)
+
+
+class LeNet(nn.Module):
+ """LeNet network."""
+
+ def __init__(
+ self,
+ channels: Tuple[int, ...],
+ kernel_sizes: Tuple[int, ...],
+ hidden_size: Tuple[int, ...],
+ dropout_rate: float,
+ output_size: int,
+ activation_fn: Optional[Callable] = None,
+ ) -> None:
+ """The LeNet network.
+
+ Args:
+ channels (Tuple[int, ...]): Channels in the convolutional layers.
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers.
+ hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
+ dropout_rate (float): The dropout rate.
+ output_size (int): Number of classes.
+ activation_fn (Optional[Callable]): The non-linear activation function. Defaults to
+ nn.ReLU(inplace).
+
+ """
+ super().__init__()
+
+ if activation_fn is None:
+ activation_fn = nn.ReLU(inplace=True)
+
+ self.layers = [
+ nn.Conv2d(
+ in_channels=channels[0],
+ out_channels=channels[1],
+ kernel_size=kernel_sizes[0],
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=channels[1],
+ out_channels=channels[2],
+ kernel_size=kernel_sizes[1],
+ ),
+ activation_fn,
+ nn.MaxPool2d(kernel_sizes[2]),
+ nn.Dropout(p=dropout_rate),
+ Flatten(),
+ nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
+ activation_fn,
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=hidden_size[1], out_features=output_size),
+ ]
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward."""
+ return self.layers(x)
+
+
+# def test():
+#     x = torch.randn([1, 1, 28, 28])
+#     channels = [1, 32, 64]
+#     kernel_sizes = [3, 3, 2]
+#     hidden_size = [9216, 128]
+#     output_size = 10
+#     dropout_rate = 0.2
+#     activation_fn = nn.ReLU()
+#     net = LeNet(
+#         channels=channels,
+#         kernel_sizes=kernel_sizes,
+#         dropout_rate=dropout_rate,
+#         hidden_size=hidden_size,
+#         output_size=output_size,
+#         activation_fn=activation_fn,
+#     )
+#     from torchsummary import summary
+#
+#     summary(net, (1, 28, 28), device="cpu")
+#     out = net(x)
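
For reference, a minimal usage sketch of the new module, using the shapes from the commented-out test above: two 3x3 convolutions and a 2x2 max pool reduce a 1x28x28 input to 64x12x12, hence the flattened size of 9216. The import path assumes src/ is on the Python path; the shape values are taken from the test, not fixed by the class itself.

    import torch

    from text_recognizer.networks.lenet import LeNet

    # Shapes mirror the commented-out test: 28x28 -> 26x26 -> 24x24 -> 12x12,
    # so the flattened feature size is 64 * 12 * 12 = 9216.
    net = LeNet(
        channels=(1, 32, 64),
        kernel_sizes=(3, 3, 2),
        hidden_size=(9216, 128),
        dropout_rate=0.2,
        output_size=10,
    )

    x = torch.randn(1, 1, 28, 28)
    logits = net(x)
    print(logits.shape)  # torch.Size([1, 10])

Note that hidden_size[0] must match the flattened convolutional output for the chosen input size; a different input resolution or kernel configuration requires recomputing it.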