summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--text_recognizer/networks/conv_transformer.py6
1 files changed, 1 insertions, 5 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ee939e7..0c838d8 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -14,9 +14,7 @@ class ConvTransformer(nn.Module):
def __init__(
self,
input_dims: Tuple[int, int, int],
- encoder_dim: int,
hidden_dim: int,
- dropout_rate: float,
num_classes: int,
pad_index: Tensor,
encoder: nn.Module,
@@ -26,9 +24,7 @@ class ConvTransformer(nn.Module):
) -> None:
super().__init__()
self.input_dims = input_dims
- self.encoder_dim = encoder_dim
self.hidden_dim = hidden_dim
- self.dropout_rate = dropout_rate
self.num_classes = num_classes
self.pad_index = pad_index
self.encoder = encoder
@@ -38,7 +34,7 @@ class ConvTransformer(nn.Module):
# positional encoding.
self.latent_encoder = nn.Sequential(
nn.Conv2d(
- in_channels=self.encoder_dim,
+ in_channels=self.encoder.out_channels,
out_channels=self.hidden_dim,
kernel_size=1,
),