summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--text_recognizer/networks/conv_transformer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index da99bbf..f07b97d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,11 +1,11 @@
"""Vision transformer for character recognition."""
import math
-from typing import List, Optional, Tuple, Type
+from typing import Optional, Tuple, Type
from loguru import logger as log
from torch import nn, Tensor
-from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
+from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
from text_recognizer.networks.transformer.embeddings.axial import (
AxialPositionalEmbedding,
)
@@ -21,7 +21,7 @@ class ConvTransformer(nn.Module):
hidden_dim: int,
num_classes: int,
pad_index: Tensor,
- encoder: nn.Module,
+ encoder: Type[nn.Module],
decoder: Decoder,
axial_encoder: Optional[AxialEncoder],
pixel_pos_embedding: AxialPositionalEmbedding,