summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conformer/subsampler.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conformer/subsampler.py')
-rw-r--r--text_recognizer/networks/conformer/subsampler.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py
index 2bc0445..53928f1 100644
--- a/text_recognizer/networks/conformer/subsampler.py
+++ b/text_recognizer/networks/conformer/subsampler.py
@@ -34,13 +34,11 @@ class Subsampler(nn.Module):
)
)
subsampler.append(nn.Mish(inplace=True))
- projector = nn.Sequential(
- nn.Flatten(start_dim=2), nn.Linear(channels, channels), nn.Dropout(dropout)
- )
+ projector = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(dropout))
return nn.Sequential(*subsampler), projector
def forward(self, x: Tensor) -> Tensor:
x = self.subsampler(x)
x = self.pixel_pos_embedding(x)
- x = self.projector(x)
- return x.permute(0, 2, 1)
+ x = x.flatten(start_dim=2).permute(0, 2, 1)
+ return self.projector(x)