summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-22 22:38:57 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-22 22:38:57 +0100
commita8210efd341b3619ffeb57135a57c161b1d4f1cf (patch)
tree8b1a5c51694bcfc5c6f7c88218102f8c852ffc23
parent050e1bd284a173d2586ad4607e95d114691db563 (diff)
Format conv transformer
-rw-r--r--text_recognizer/networks/conv_transformer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index da99bbf..f07b97d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,11 +1,11 @@
"""Vision transformer for character recognition."""
import math
-from typing import List, Optional, Tuple, Type
+from typing import Optional, Tuple, Type
from loguru import logger as log
from torch import nn, Tensor
-from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
+from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
from text_recognizer.networks.transformer.embeddings.axial import (
AxialPositionalEmbedding,
)
@@ -21,7 +21,7 @@ class ConvTransformer(nn.Module):
hidden_dim: int,
num_classes: int,
pad_index: Tensor,
- encoder: nn.Module,
+ encoder: Type[nn.Module],
decoder: Decoder,
axial_encoder: Optional[AxialEncoder],
pixel_pos_embedding: AxialPositionalEmbedding,