summary refs log tree commit diff
path: root/text_recognizer/networks/image_transformer.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-05 23:12:20 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-04-05 23:12:20 +0200
commit 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 (patch)
tree aaa3f56495cdfbcc5f1434485fb237dfd6cf34a2 /text_recognizer/networks/image_transformer.py
parent bef106191e20b42741984c407dc4884ab1ee49eb (diff)
Add OmegaConf for configs
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r-- text_recognizer/networks/image_transformer.py 42
1 files changed, 27 insertions, 15 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index b9254c9..aa024e0 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -8,10 +8,12 @@ together with the target tokens.
TODO: Local attention for transformer.j
"""
+import importlib
import math
-from typing import Any, Dict, List, Optional, Sequence, Type
+from typing import Dict, List, Union, Sequence, Tuple, Type
from einops import rearrange
+from omegaconf import OmegaConf
import torch
from torch import nn
from torch import Tensor
@@ -32,8 +34,8 @@ class ImageTransformer(nn.Module):
self,
input_shape: Sequence[int],
output_shape: Sequence[int],
- backbone: Type[nn.Module],
- mapping: Optional[List[str]] = None,
+ encoder: Union[OmegaConf, Dict],
+ mapping: str,
num_decoder_layers: int = 4,
hidden_dim: int = 256,
num_heads: int = 4,
@@ -51,8 +53,8 @@ class ImageTransformer(nn.Module):
self.pad_index = inverse_mapping["<p>"]
# Image backbone
- self.backbone = backbone
- self.latent_encoding = PositionalEncoding2D(
+ self.encoder = self._configure_encoder(encoder)
+ self.feature_map_encoding = PositionalEncoding2D(
hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
)
@@ -86,20 +88,30 @@ class ImageTransformer(nn.Module):
self.head.weight.data.uniform_(-0.1, 0.1)
nn.init.kaiming_normal_(
- self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu"
+ self.feature_map_encoding.weight.data,
+ a=0,
+ mode="fan_out",
+ nonlinearity="relu",
)
- if self.latent_encoding.bias is not None:
+ if self.feature_map_encoding.bias is not None:
_, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.latent_encoding.weight.data
+ self.feature_map_encoding.weight.data
)
bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.latent_encoding.bias, -bound, bound)
+ nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+    @staticmethod
+    def _configure_encoder(encoder: Union[OmegaConf, Dict]) -> nn.Module:
+        """Instantiates the image encoder described by a config.
+
+        Args:
+            encoder: Config with a `type` entry naming an attribute of the
+                `text_recognizer.networks` package, and an `args` entry whose
+                items are passed to it as keyword arguments.
+
+        Returns:
+            The object produced by calling the looked-up class with `args`.
+
+        Raises:
+            AttributeError: If `text_recognizer.networks` has no attribute
+                named by `encoder.type`.
+        """
+        # NOTE(review): the annotation previously referenced `NamedTuple`,
+        # which the typing import of this file does not provide; it now
+        # matches the `encoder` annotation used by `__init__`.
+        # Normalise plain dicts to an OmegaConf config so that the attribute
+        # access below (`encoder.type`, `encoder.args`) works for both inputs.
+        encoder = OmegaConf.create(encoder)
+        network_module = importlib.import_module("text_recognizer.networks")
+        encoder_class = getattr(network_module, encoder.type)
+        return encoder_class(**encoder.args)
def _configure_mapping(
- self, mapping: Optional[List[str]]
+ self, mapping: str
) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
- if mapping is None:
+ if mapping == "emnist":
mapping, inverse_mapping, _ = emnist_mapping()
return mapping, inverse_mapping
@@ -118,14 +130,14 @@ class ImageTransformer(nn.Module):
"""
# Extract image features.
- latent = self.backbone(image)
+ image_features = self.encoder(image)
# Add 2d encoding to the feature maps.
- latent = self.latent_encoding(latent)
+ image_features = self.feature_map_encoding(image_features)
# Collapse features maps height and width.
- latent = rearrange(latent, "b c h w -> b (h w) c")
- return latent
+ image_features = rearrange(image_features, "b c h w -> b (h w) c")
+ return image_features
def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
"""Decodes image features with transformer decoder."""