summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/cnn.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-05-30 23:34:30 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-05-30 23:34:30 +0200
commit84d4147a342648398d16ad2c0bdbeacfbb4b3caa (patch)
treea372dcd1c654b86ff46e14566fcf527988fbb208 /text_recognizer/networks/cnn.py
parent52e811685374b07f3a82bf13a4e568d182045e68 (diff)
Add a basic cnn encoder
Diffstat (limited to 'text_recognizer/networks/cnn.py')
-rw-r--r--text_recognizer/networks/cnn.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..5e2a7f4
--- /dev/null
+++ b/text_recognizer/networks/cnn.py
@@ -0,0 +1,26 @@
+"""Simple convolutional network."""
+import torch
+from torch import nn, Tensor
+
+
+class CNN(nn.Module):
+ def __init__(self, channels: int, depth: int) -> None:
+ super().__init__()
+ self.layers = self._build(channels, depth)
+
+ def _build(self, channels: int, depth: int) -> nn.Sequential:
+ layers = []
+ for i in range(depth):
+ layers.append(
+ nn.Conv2d(
+ in_channels=1 if i == 0 else channels,
+ out_channels=channels,
+ kernel_size=3,
+ stride=2,
+ )
+ )
+ layers.append(nn.Mish(inplace=True))
+ return nn.Sequential(*layers)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layers(x)