summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transducer/tds_conv.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-02-24 22:00:29 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-02-24 22:00:29 +0100
commit905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 (patch)
tree91dab598a94911e6147b996237e786dd47f11f2f /src/text_recognizer/networks/transducer/tds_conv.py
parent4a54d7e690897dd6e6c719fb908fd371a44c2952 (diff)
updates
Diffstat (limited to 'src/text_recognizer/networks/transducer/tds_conv.py')
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
index 018caf2..5fb8ba9 100644
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ b/src/text_recognizer/networks/transducer/tds_conv.py
@@ -136,8 +136,10 @@ class TDS2d(nn.Module):
self.tds = None
self.fc = None
- def _build_network(self) -> None:
+ self._build_network()
+ def _build_network(self) -> None:
+ in_channels = self.in_channels
modules = []
stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
if self.input_dim % stride_h:
@@ -151,7 +153,7 @@ class TDS2d(nn.Module):
modules.extend(
[
nn.Conv2d(
- in_channels=self.in_channels,
+ in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.kernel_size,
padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
@@ -173,12 +175,10 @@ class TDS2d(nn.Module):
)
)
- self.in_channels = out_channels
+ in_channels = out_channels
self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(
- self.in_channels * self.input_dim // stride_h, self.output_dim
- )
+ self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
def forward(self, x: Tensor) -> Tensor:
"""Forward pass.
@@ -193,6 +193,9 @@ class TDS2d(nn.Module):
Tensor: Output tensor.
"""
+ if len(x.shape) == 4:
+ x = x.squeeze(1) # Squeeze the channel dim away.
+
B, H, W = x.shape
x = rearrange(
x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels