diff options
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index edebca9..9ed67a4 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -13,7 +13,7 @@ import math from typing import Dict, List, Union, Sequence, Tuple, Type from einops import rearrange -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor @@ -34,7 +34,7 @@ class ImageTransformer(nn.Module): self, input_shape: Sequence[int], output_shape: Sequence[int], - encoder: Union[OmegaConf, Dict], + encoder: Union[DictConfig, Dict], mapping: str, num_decoder_layers: int = 4, hidden_dim: int = 256, @@ -101,7 +101,7 @@ class ImageTransformer(nn.Module): nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) @staticmethod - def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) |