summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/cnn.py
blob: 5e2a7f4284e21e16659b5cd29fdc704e91c09f80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Simple convolutional network."""
import torch
from torch import nn, Tensor


class CNN(nn.Module):
    """Stack of strided 3x3 convolutions, each followed by a Mish activation.

    Every convolution uses ``kernel_size=3``, ``stride=2`` and no padding, so
    each layer maps a spatial size ``s`` to ``(s - 3) // 2 + 1``.

    Args:
        channels: Number of output channels of every convolution.
        depth: Number of conv + Mish pairs. A non-positive depth produces an
            empty ``nn.Sequential``, i.e. the module returns its input as-is.
        in_channels: Number of channels of the input fed to the first
            convolution. Defaults to ``1``, the value previously hard-coded,
            so existing callers are unaffected.
    """

    def __init__(self, channels: int, depth: int, in_channels: int = 1) -> None:
        super().__init__()
        self.layers = self._build(channels, depth, in_channels)

    def _build(
        self, channels: int, depth: int, in_channels: int = 1
    ) -> nn.Sequential:
        """Return ``depth`` (Conv2d, Mish) pairs wrapped in an ``nn.Sequential``."""
        layers = []
        for i in range(depth):
            layers.append(
                nn.Conv2d(
                    # Only the first conv sees the raw input; later convs
                    # consume the ``channels`` maps produced by the previous one.
                    in_channels=in_channels if i == 0 else channels,
                    out_channels=channels,
                    kernel_size=3,
                    stride=2,
                )
            )
            # In-place Mish overwrites the conv output, which nothing else in
            # this Sequential reads, saving one activation-sized allocation.
            layers.append(nn.Mish(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the conv/Mish stack to ``x``.

        ``x`` is passed straight to ``nn.Conv2d``, so it is presumably
        ``(N, C, H, W)`` or ``(C, H, W)`` with ``C == in_channels``.
        """
        return self.layers(x)