diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-03 22:13:51 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-03 22:13:51 +0100 |
commit | 9381895e6f0154b0e9acc9e540266367e8a35843 (patch) | |
tree | 390a7627d6e13fb6139bd713ca9ae61134731afd /text_recognizer/networks | |
parent | 6cb2259592a1868fda069d7f7b3f1688c56f1912 (diff) |
Remove unused args from conv transformer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index ee939e7..0c838d8 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -14,9 +14,7 @@ class ConvTransformer(nn.Module): def __init__( self, input_dims: Tuple[int, int, int], - encoder_dim: int, hidden_dim: int, - dropout_rate: float, num_classes: int, pad_index: Tensor, encoder: nn.Module, @@ -26,9 +24,7 @@ class ConvTransformer(nn.Module): ) -> None: super().__init__() self.input_dims = input_dims - self.encoder_dim = encoder_dim self.hidden_dim = hidden_dim - self.dropout_rate = dropout_rate self.num_classes = num_classes self.pad_index = pad_index self.encoder = encoder @@ -38,7 +34,7 @@ class ConvTransformer(nn.Module): # positional encoding. self.latent_encoder = nn.Sequential( nn.Conv2d( - in_channels=self.encoder_dim, + in_channels=self.encoder.out_channels, out_channels=self.hidden_dim, kernel_size=1, ), |