diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
commit | 3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch) | |
tree | 136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/conv_transformer.py | |
parent | 1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff) |
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 7 |
1 file changed, 4 insertions, 3 deletions
```diff
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index f3ba49d..b1a101e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -4,7 +4,6 @@ from typing import Tuple
 
 from torch import nn, Tensor
 
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
@@ -18,15 +17,17 @@ class ConvTransformer(nn.Module):
     def __init__(
         self,
         input_dims: Tuple[int, int, int],
+        encoder_dim: int,
         hidden_dim: int,
         dropout_rate: float,
         num_classes: int,
         pad_index: Tensor,
-        encoder: EfficientNet,
+        encoder: nn.Module,
         decoder: Decoder,
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
+        self.encoder_dim = encoder_dim
         self.hidden_dim = hidden_dim
         self.dropout_rate = dropout_rate
         self.num_classes = num_classes
@@ -38,7 +39,7 @@ class ConvTransformer(nn.Module):
         # positional encoding.
         self.latent_encoder = nn.Sequential(
             nn.Conv2d(
-                in_channels=self.encoder.out_channels,
+                in_channels=self.encoder_dim,
                 out_channels=self.hidden_dim,
                 kernel_size=1,
             ),
```