diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-22 22:38:57 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-22 22:38:57 +0100 |
commit | a8210efd341b3619ffeb57135a57c161b1d4f1cf (patch) | |
tree | 8b1a5c51694bcfc5c6f7c88218102f8c852ffc23 /text_recognizer | |
parent | 050e1bd284a173d2586ad4607e95d114691db563 (diff) |
Format conv transformer
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index da99bbf..f07b97d 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,11 +1,11 @@ """Vision transformer for character recognition.""" import math -from typing import List, Optional, Tuple, Type +from typing import Optional, Tuple, Type from loguru import logger as log from torch import nn, Tensor -from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder +from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder from text_recognizer.networks.transformer.embeddings.axial import ( AxialPositionalEmbedding, ) @@ -21,7 +21,7 @@ class ConvTransformer(nn.Module): hidden_dim: int, num_classes: int, pad_index: Tensor, - encoder: nn.Module, + encoder: Type[nn.Module], decoder: Decoder, axial_encoder: Optional[AxialEncoder], pixel_pos_embedding: AxialPositionalEmbedding, |