summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/image_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-06 22:31:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-06 22:31:54 +0200
commit31d58f2108165802d26eb1c1bdb9e5f052b4dd26 (patch)
tree6f5c2dcb0eef814c71a34df98444be7e8f1d0b43 /text_recognizer/networks/image_transformer.py
parent5e11924ca6aaea7898caca94675f41f67706a406 (diff)
Fix network args
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r--text_recognizer/networks/image_transformer.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 85a84d2..edebca9 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -5,7 +5,7 @@ i.e. feature maps. A 2d positional encoding is applied to the feature maps for
spatial information. The resulting feature are then set to a transformer decoder
together with the target tokens.
-TODO: Local attention for transformer.j
+TODO: Local attention for lower layer in attention.
"""
import importlib
@@ -39,7 +39,7 @@ class ImageTransformer(nn.Module):
num_decoder_layers: int = 4,
hidden_dim: int = 256,
num_heads: int = 4,
- expansion_dim: int = 4,
+ expansion_dim: int = 1024,
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
) -> None:
@@ -109,6 +109,7 @@ class ImageTransformer(nn.Module):
def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
+ # TODO: Fix me!!!
if mapping == "emnist":
mapping, inverse_mapping, _ = emnist_mapping()
return mapping, inverse_mapping