diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 5 |
2 files changed, 6 insertions, 2 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index e69de29..4dcaf2e 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -0,0 +1,3 @@ +"""Network modules""" +from .image_transformer import ImageTransformer + diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 85a84d2..edebca9 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -5,7 +5,7 @@ i.e. feature maps. A 2d positional encoding is applied to the feature maps for spatial information. The resulting feature are then set to a transformer decoder together with the target tokens. -TODO: Local attention for transformer.j +TODO: Local attention for lower layer in attention. """ import importlib @@ -39,7 +39,7 @@ class ImageTransformer(nn.Module): num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, - expansion_dim: int = 4, + expansion_dim: int = 1024, dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: @@ -109,6 +109,7 @@ class ImageTransformer(nn.Module): def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" + # TODO: Fix me!!! if mapping == "emnist": mapping, inverse_mapping, _ = emnist_mapping() return mapping, inverse_mapping |