summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/image_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r--text_recognizer/networks/image_transformer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index edebca9..9ed67a4 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -13,7 +13,7 @@ import math
from typing import Dict, List, Union, Sequence, Tuple, Type
from einops import rearrange
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
@@ -34,7 +34,7 @@ class ImageTransformer(nn.Module):
self,
input_shape: Sequence[int],
output_shape: Sequence[int],
- encoder: Union[OmegaConf, Dict],
+ encoder: Union[DictConfig, Dict],
mapping: str,
num_decoder_layers: int = 4,
hidden_dim: int = 256,
@@ -101,7 +101,7 @@ class ImageTransformer(nn.Module):
nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
@staticmethod
- def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
+ def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
encoder = OmegaConf.create(encoder)
network_module = importlib.import_module("text_recognizer.networks")
encoder_class = getattr(network_module, encoder.type)