diff options
Diffstat (limited to 'src/text_recognizer/networks/transducer/tds_conv.py')
-rw-r--r-- | src/text_recognizer/networks/transducer/tds_conv.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py index 018caf2..5fb8ba9 100644 --- a/src/text_recognizer/networks/transducer/tds_conv.py +++ b/src/text_recognizer/networks/transducer/tds_conv.py @@ -136,8 +136,10 @@ class TDS2d(nn.Module): self.tds = None self.fc = None - def _build_network(self) -> None: + self._build_network() + def _build_network(self) -> None: + in_channels = self.in_channels modules = [] stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) if self.input_dim % stride_h: @@ -151,7 +153,7 @@ class TDS2d(nn.Module): modules.extend( [ nn.Conv2d( - in_channels=self.in_channels, + in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), @@ -173,12 +175,10 @@ class TDS2d(nn.Module): ) ) - self.in_channels = out_channels + in_channels = out_channels self.tds = nn.Sequential(*modules) - self.fc = nn.Linear( - self.in_channels * self.input_dim // stride_h, self.output_dim - ) + self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) def forward(self, x: Tensor) -> Tensor: """Forward pass. @@ -193,6 +193,9 @@ class TDS2d(nn.Module): Tensor: Output tensor. """ + if len(x.shape) == 4: + x = x.squeeze(1) # Squeeze the channel dim away. + B, H, W = x.shape x = rearrange( x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels |