diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:41:09 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:41:09 +0200 |
commit | 7b660c13ce3c0edeace1107838e62c559bc6f078 (patch) | |
tree | 117e8ca03815282907f7ba8da296ebc99de8ea7c /text_recognizer/networks/conformer/subsampler.py | |
parent | 8ae1b802bb7d7c63cf758e44269e97a4c0788b65 (diff) |
Fix conformer net
Diffstat (limited to 'text_recognizer/networks/conformer/subsampler.py')
-rw-r--r-- | text_recognizer/networks/conformer/subsampler.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py index 2bc0445..53928f1 100644 --- a/text_recognizer/networks/conformer/subsampler.py +++ b/text_recognizer/networks/conformer/subsampler.py @@ -34,13 +34,11 @@ class Subsampler(nn.Module): ) ) subsampler.append(nn.Mish(inplace=True)) - projector = nn.Sequential( - nn.Flatten(start_dim=2), nn.Linear(channels, channels), nn.Dropout(dropout) - ) + projector = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(dropout)) return nn.Sequential(*subsampler), projector def forward(self, x: Tensor) -> Tensor: x = self.subsampler(x) x = self.pixel_pos_embedding(x) - x = self.projector(x) - return x.permute(0, 2, 1) + x = x.flatten(start_dim=2).permute(0, 2, 1) + return self.projector(x) |