From 9381895e6f0154b0e9acc9e540266367e8a35843 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 3 Nov 2021 22:13:51 +0100 Subject: Remove unused args from conv transformer --- text_recognizer/networks/conv_transformer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index ee939e7..0c838d8 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -14,9 +14,7 @@ class ConvTransformer(nn.Module): def __init__( self, input_dims: Tuple[int, int, int], - encoder_dim: int, hidden_dim: int, - dropout_rate: float, num_classes: int, pad_index: Tensor, encoder: nn.Module, @@ -26,9 +24,7 @@ class ConvTransformer(nn.Module): ) -> None: super().__init__() self.input_dims = input_dims - self.encoder_dim = encoder_dim self.hidden_dim = hidden_dim - self.dropout_rate = dropout_rate self.num_classes = num_classes self.pad_index = pad_index self.encoder = encoder @@ -38,7 +34,7 @@ class ConvTransformer(nn.Module): # positional encoding. self.latent_encoder = nn.Sequential( nn.Conv2d( - in_channels=self.encoder_dim, + in_channels=self.encoder.out_channels, out_channels=self.hidden_dim, kernel_size=1, ), -- cgit v1.2.3-70-g09d2