summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/__init__.py3
-rw-r--r--text_recognizer/networks/image_transformer.py5
2 files changed, 6 insertions, 2 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index e69de29..4dcaf2e 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -0,0 +1,3 @@
+"""Network modules"""
+from .image_transformer import ImageTransformer
+
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index 85a84d2..edebca9 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -5,7 +5,7 @@ i.e. feature maps. A 2d positional encoding is applied to the feature maps for
spatial information. The resulting feature are then set to a transformer decoder
together with the target tokens.
-TODO: Local attention for transformer.j
+TODO: Local attention for lower layer in attention.
"""
import importlib
@@ -39,7 +39,7 @@ class ImageTransformer(nn.Module):
num_decoder_layers: int = 4,
hidden_dim: int = 256,
num_heads: int = 4,
- expansion_dim: int = 4,
+ expansion_dim: int = 1024,
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
) -> None:
@@ -109,6 +109,7 @@ class ImageTransformer(nn.Module):
def _configure_mapping(self, mapping: str) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
+ # TODO: Fix me!!!
if mapping == "emnist":
mapping, inverse_mapping, _ = emnist_mapping()
return mapping, inverse_mapping