From 84d4147a342648398d16ad2c0bdbeacfbb4b3caa Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 30 May 2022 23:34:30 +0200 Subject: Add a basic cnn encoder --- text_recognizer/networks/cnn.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 text_recognizer/networks/cnn.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py new file mode 100644 index 0000000..5e2a7f4 --- /dev/null +++ b/text_recognizer/networks/cnn.py @@ -0,0 +1,26 @@ +"""Simple convolutional network.""" +import torch +from torch import nn, Tensor + + +class CNN(nn.Module): + def __init__(self, channels: int, depth: int) -> None: + super().__init__() + self.layers = self._build(channels, depth) + + def _build(self, channels: int, depth: int) -> nn.Sequential: + layers = [] + for i in range(depth): + layers.append( + nn.Conv2d( + in_channels=1 if i == 0 else channels, + out_channels=channels, + kernel_size=3, + stride=2, + ) + ) + layers.append(nn.Mish(inplace=True)) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.layers(x) -- cgit v1.2.3-70-g09d2