diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-07 22:12:10 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-07 22:12:10 +0200 |
commit | 8afa8e1c6e9623b0dea86236da04b2b4173e9443 (patch) | |
tree | 4c9462507b3b3076aa26f08ab629f64b90aed2cb /text_recognizer/networks/image_transformer.py | |
parent | 33190bc9c0c377edab280efe4b0bd0e53bb6cb00 (diff) |
Fixed typing and typos, train script load config, reformatted
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index edebca9..9ed67a4 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -13,7 +13,7 @@ import math from typing import Dict, List, Union, Sequence, Tuple, Type from einops import rearrange -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor @@ -34,7 +34,7 @@ class ImageTransformer(nn.Module): self, input_shape: Sequence[int], output_shape: Sequence[int], - encoder: Union[OmegaConf, Dict], + encoder: Union[DictConfig, Dict], mapping: str, num_decoder_layers: int = 4, hidden_dim: int = 256, @@ -101,7 +101,7 @@ class ImageTransformer(nn.Module): nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) @staticmethod - def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) |