diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:12:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:12:20 +0200 |
commit | 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 (patch) | |
tree | aaa3f56495cdfbcc5f1434485fb237dfd6cf34a2 /text_recognizer/networks/image_transformer.py | |
parent | bef106191e20b42741984c407dc4884ab1ee49eb (diff) |
Add OmegaConf for configs
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index b9254c9..aa024e0 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -8,10 +8,12 @@ together with the target tokens. TODO: Local attention for transformer.j """ +import importlib import math -from typing import Any, Dict, List, Optional, Sequence, Type +from typing import Dict, List, Union, Sequence, Tuple, Type from einops import rearrange +from omegaconf import OmegaConf import torch from torch import nn from torch import Tensor @@ -32,8 +34,8 @@ class ImageTransformer(nn.Module): self, input_shape: Sequence[int], output_shape: Sequence[int], - backbone: Type[nn.Module], - mapping: Optional[List[str]] = None, + encoder: Union[OmegaConf, Dict], + mapping: str, num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, @@ -51,8 +53,8 @@ class ImageTransformer(nn.Module): self.pad_index = inverse_mapping["<p>"] # Image backbone - self.backbone = backbone - self.latent_encoding = PositionalEncoding2D( + self.encoder = self._configure_encoder(encoder) + self.feature_map_encoding = PositionalEncoding2D( hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] ) @@ -86,20 +88,30 @@ class ImageTransformer(nn.Module): self.head.weight.data.uniform_(-0.1, 0.1) nn.init.kaiming_normal_( - self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu" + self.feature_map_encoding.weight.data, + a=0, + mode="fan_out", + nonlinearity="relu", ) - if self.latent_encoding.bias is not None: + if self.feature_map_encoding.bias is not None: _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.latent_encoding.weight.data + self.feature_map_encoding.weight.data ) bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.latent_encoding.bias, -bound, bound) + nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + + @staticmethod + def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: + encoder = OmegaConf.create(encoder) + network_module = importlib.import_module("text_recognizer.networks") + encoder_class = getattr(network_module, encoder.type) + return encoder_class(**encoder.args) def _configure_mapping( - self, mapping: Optional[List[str]] + self, mapping: str ) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" - if mapping is None: + if mapping == "emnist": mapping, inverse_mapping, _ = emnist_mapping() return mapping, inverse_mapping @@ -118,14 +130,14 @@ class ImageTransformer(nn.Module): """ # Extract image features. - latent = self.backbone(image) + image_features = self.encoder(image) # Add 2d encoding to the feature maps. - latent = self.latent_encoding(latent) + image_features = self.feature_map_encoding(image_features) # Collapse features maps height and width. - latent = rearrange(latent, "b c h w -> b (h w) c") - return latent + image_features = rearrange(image_features, "b c h w -> b (h w) c") + return image_features def decode(self, memory: Tensor, trg: Tensor) -> Tensor: """Decodes image features with transformer decoder.""" |