diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-05-30 23:34:30 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-05-30 23:34:30 +0200 |
commit | 84d4147a342648398d16ad2c0bdbeacfbb4b3caa (patch) | |
tree | a372dcd1c654b86ff46e14566fcf527988fbb208 /text_recognizer/networks/cnn.py | |
parent | 52e811685374b07f3a82bf13a4e568d182045e68 (diff) |
Add a basic cnn encoder
Diffstat (limited to 'text_recognizer/networks/cnn.py')
-rw-r--r-- | text_recognizer/networks/cnn.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py new file mode 100644 index 0000000..5e2a7f4 --- /dev/null +++ b/text_recognizer/networks/cnn.py @@ -0,0 +1,26 @@ +"""Simple convolutional network.""" +import torch +from torch import nn, Tensor + + +class CNN(nn.Module): + def __init__(self, channels: int, depth: int) -> None: + super().__init__() + self.layers = self._build(channels, depth) + + def _build(self, channels: int, depth: int) -> nn.Sequential: + layers = [] + for i in range(depth): + layers.append( + nn.Conv2d( + in_channels=1 if i == 0 else channels, + out_channels=channels, + kernel_size=3, + stride=2, + ) + ) + layers.append(nn.Mish(inplace=True)) + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.layers(x) |