diff options
Diffstat (limited to 'text_recognizer/networks')
| -rw-r--r-- | text_recognizer/networks/cnn.py | 26 | 
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py new file mode 100644 index 0000000..5e2a7f4 --- /dev/null +++ b/text_recognizer/networks/cnn.py @@ -0,0 +1,26 @@ +"""Simple convolutional network.""" +import torch +from torch import nn, Tensor + + +class CNN(nn.Module): +    def __init__(self, channels: int, depth: int) -> None: +        super().__init__() +        self.layers = self._build(channels, depth) + +    def _build(self, channels: int, depth: int) -> nn.Sequential: +        layers = [] +        for i in range(depth): +            layers.append( +                nn.Conv2d( +                    in_channels=1 if i == 0 else channels, +                    out_channels=channels, +                    kernel_size=3, +                    stride=2, +                ) +            ) +            layers.append(nn.Mish(inplace=True)) +        return nn.Sequential(*layers) + +    def forward(self, x: Tensor) -> Tensor: +        return self.layers(x)  |