summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-08 08:41:09 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-08 08:41:09 +0200
commit7b660c13ce3c0edeace1107838e62c559bc6f078 (patch)
tree117e8ca03815282907f7ba8da296ebc99de8ea7c
parent8ae1b802bb7d7c63cf758e44269e97a4c0788b65 (diff)
Fix conformer net
-rw-r--r--text_recognizer/networks/conformer/block.py4
-rw-r--r--text_recognizer/networks/conformer/conformer.py5
-rw-r--r--text_recognizer/networks/conformer/subsampler.py8
3 files changed, 9 insertions, 8 deletions
diff --git a/text_recognizer/networks/conformer/block.py b/text_recognizer/networks/conformer/block.py
index 4ea33c0..c53f339 100644
--- a/text_recognizer/networks/conformer/block.py
+++ b/text_recognizer/networks/conformer/block.py
@@ -26,9 +26,9 @@ class ConformerBlock(nn.Module):
self.conv = conv
self.post_norm = nn.LayerNorm(dim)
- def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ def forward(self, x: Tensor) -> Tensor:
x = self.ff_1(x) + x
- x = self.attn(x, input_mask=mask) + x
+ x = self.attn(x) + x
x = self.conv(x) + x
x = self.ff_2(x) + x
return self.post_norm(x)
diff --git a/text_recognizer/networks/conformer/conformer.py b/text_recognizer/networks/conformer/conformer.py
index 8d0e98e..e2dce27 100644
--- a/text_recognizer/networks/conformer/conformer.py
+++ b/text_recognizer/networks/conformer/conformer.py
@@ -10,6 +10,8 @@ from text_recognizer.networks.conformer.block import ConformerBlock
class Conformer(nn.Module):
def __init__(
self,
+ dim: int,
+ num_classes: int,
subsampler: Type[nn.Module],
block: ConformerBlock,
depth: int,
@@ -17,9 +19,10 @@ class Conformer(nn.Module):
super().__init__()
self.subsampler = subsampler
self.blocks = nn.ModuleList([deepcopy(block) for _ in range(depth)])
+ self.fc = nn.Linear(dim, num_classes, bias=False)
def forward(self, x: Tensor) -> Tensor:
x = self.subsampler(x)
for fn in self.blocks:
x = fn(x)
- return x
+ return self.fc(x)
diff --git a/text_recognizer/networks/conformer/subsampler.py b/text_recognizer/networks/conformer/subsampler.py
index 2bc0445..53928f1 100644
--- a/text_recognizer/networks/conformer/subsampler.py
+++ b/text_recognizer/networks/conformer/subsampler.py
@@ -34,13 +34,11 @@ class Subsampler(nn.Module):
)
)
subsampler.append(nn.Mish(inplace=True))
- projector = nn.Sequential(
- nn.Flatten(start_dim=2), nn.Linear(channels, channels), nn.Dropout(dropout)
- )
+ projector = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(dropout))
return nn.Sequential(*subsampler), projector
def forward(self, x: Tensor) -> Tensor:
x = self.subsampler(x)
x = self.pixel_pos_embedding(x)
- x = self.projector(x)
- return x.permute(0, 2, 1)
+ x = x.flatten(start_dim=2).permute(0, 2, 1)
+ return self.projector(x)