diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-07 00:24:28 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-07 00:24:28 +0200 |
commit | 38dc6ca3b787bcdb54d43ac5c076e08af25d44b2 (patch) | |
tree | df12ee98c797c44c61f02369cf8cb794d6f47b7c /text_recognizer/networks | |
parent | 7d759b6c0efcb58b5c7c6858d7dcbd2060992430 (diff) |
Add subsampler layer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/conformer/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/conformer.py | 10 | ||||
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 46 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 3 |
4 files changed, 59 insertions, 1 deletions
diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py index 5f3c7b5..1886f85 100644 --- a/text_recognizer/networks/conformer/__init__.py +++ b/text_recognizer/networks/conformer/__init__.py @@ -3,3 +3,4 @@ from text_recognizer.networks.conformer.ff import Feedforward from text_recognizer.networks.conformer.glu import GLU from text_recognizer.networks.conformer.conformer import Conformer from text_recognizer.networks.conformer.conv import ConformerConv +from text_recognizer.networks.conformer.subsampler import Subsampler diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py index d56955e..8d0e98e 100644 --- a/text_recognizer/networks/conformer/conformer.py +++ b/text_recognizer/networks/conformer/conformer.py @@ -1,5 +1,6 @@ """Conformer module.""" from copy import deepcopy +from typing import Type from torch import nn, Tensor @@ -7,11 +8,18 @@ from text_recognizer.networks.conformer.block import ConformerBlock class Conformer(nn.Module): - def __init__(self, block: ConformerBlock, depth: int) -> None: + def __init__( + self, + subsampler: Type[nn.Module], + block: ConformerBlock, + depth: int, + ) -> None: super().__init__() + self.subsampler = subsampler self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)]) def forward(self, x: Tensor) -> Tensor: + x = self.subsampler(x) for fn in self.blocks: x = fn(x) return x diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py new file mode 100644 index 0000000..2bc0445 --- /dev/null +++ b/text_recognizer/networks/conformer/subsampler.py @@ -0,0 +1,46 @@ +"""Simple convolutional network.""" +from typing import Tuple + +from torch import nn, Tensor + +from text_recognizer.networks.transformer import ( + AxialPositionalEmbedding, +) + + +class Subsampler(nn.Module): + def __init__( + self, + channels: int, + depth: int, + pixel_pos_embedding: AxialPositionalEmbedding, + dropout: float = 0.1, + ) -> None: + super().__init__() + self.pixel_pos_embedding = pixel_pos_embedding + self.subsampler, self.projector = self._build(channels, depth, dropout) + + def _build( + self, channels: int, depth: int, dropout: float + ) -> Tuple[nn.Sequential, nn.Sequential]: + subsampler = [] + for i in range(depth): + subsampler.append( + nn.Conv2d( + in_channels=1 if i == 0 else channels, + out_channels=channels, + kernel_size=3, + stride=2, + ) + ) + subsampler.append(nn.Mish(inplace=True)) + projector = nn.Sequential( + nn.Flatten(start_dim=2), nn.Linear(channels, channels), nn.Dropout(dropout) + ) + return nn.Sequential(*subsampler), projector + + def forward(self, x: Tensor) -> Tensor: + x = self.subsampler(x) + x = self.pixel_pos_embedding(x) + x = self.projector(x) + return x.permute(0, 2, 1) diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 041d257..d867800 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1,3 +1,6 @@ """Transformer modules.""" from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding from text_recognizer.networks.transformer.attention import Attention +from text_recognizer.networks.transformer.embeddings.axial import ( + AxialPositionalEmbedding, +) |