diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/cnn.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py new file mode 100644 index 0000000..5e2a7f4 --- /dev/null +++ b/text_recognizer/networks/cnn.py @@ -0,0 +1,26 @@ +"""Simple convolutional network.""" +import torch +from torch import nn, Tensor + + +class CNN(nn.Module): + def __init__(self, channels: int, depth: int) -> None: + super().__init__() + self.layers = self._build(channels, depth) + + def _build(self, channels: int, depth: int) -> nn.Sequential: + layers = [] + for i in range(depth): + layers.append( + nn.Conv2d( + in_channels=1 if i == 0 else channels, + out_channels=channels, + kernel_size=3, + stride=2, + ) + ) + layers.append(nn.Mish(inplace=True)) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.layers(x) |