From 31d58f2108165802d26eb1c1bdb9e5f052b4dd26 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 6 Apr 2021 22:31:54 +0200 Subject: Fix network args --- text_recognizer/networks/__init__.py | 3 +++ text_recognizer/networks/image_transformer.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index e69de29..4dcaf2e 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -0,0 +1,3 @@ +"""Network modules""" +from .image_transformer import ImageTransformer + diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index 85a84d2..edebca9 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -5,7 +5,7 @@ i.e. feature maps. A 2d positional encoding is applied to the feature maps for spatial information. The resulting feature are then set to a transformer decoder together with the target tokens. -TODO: Local attention for transformer.j +TODO: Local attention for lower layer in attention. """ import importlib @@ -39,7 +39,7 @@ class ImageTransformer(nn.Module): num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, - expansion_dim: int = 4, + expansion_dim: int = 1024, dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: @@ -109,6 +109,7 @@ class ImageTransformer(nn.Module): def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" + # TODO: Fix me!!! if mapping == "emnist": mapping, inverse_mapping, _ = emnist_mapping() return mapping, inverse_mapping -- cgit v1.2.3-70-g09d2