summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--text_recognizer/networks/conv_transformer.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index f3ba49d..b1a101e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -4,7 +4,6 @@ from typing import Tuple
from torch import nn, Tensor
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
PositionalEncoding,
@@ -18,15 +17,17 @@ class ConvTransformer(nn.Module):
def __init__(
self,
input_dims: Tuple[int, int, int],
+ encoder_dim: int,
hidden_dim: int,
dropout_rate: float,
num_classes: int,
pad_index: Tensor,
- encoder: EfficientNet,
+ encoder: nn.Module,
decoder: Decoder,
) -> None:
super().__init__()
self.input_dims = input_dims
+ self.encoder_dim = encoder_dim
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
self.num_classes = num_classes
@@ -38,7 +39,7 @@ class ConvTransformer(nn.Module):
# positional encoding.
self.latent_encoder = nn.Sequential(
nn.Conv2d(
- in_channels=self.encoder.out_channels,
+ in_channels=self.encoder_dim,
out_channels=self.hidden_dim,
kernel_size=1,
),