summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/image_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-07 22:12:10 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-07 22:12:10 +0200
commit8afa8e1c6e9623b0dea86236da04b2b4173e9443 (patch)
tree4c9462507b3b3076aa26f08ab629f64b90aed2cb /text_recognizer/networks/image_transformer.py
parent33190bc9c0c377edab280efe4b0bd0e53bb6cb00 (diff)
Fixed typing and typos, train script load config, reformatted
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r--text_recognizer/networks/image_transformer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index edebca9..9ed67a4 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -13,7 +13,7 @@ import math
from typing import Dict, List, Union, Sequence, Tuple, Type
from einops import rearrange
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
@@ -34,7 +34,7 @@ class ImageTransformer(nn.Module):
self,
input_shape: Sequence[int],
output_shape: Sequence[int],
- encoder: Union[OmegaConf, Dict],
+ encoder: Union[DictConfig, Dict],
mapping: str,
num_decoder_layers: int = 4,
hidden_dim: int = 256,
@@ -101,7 +101,7 @@ class ImageTransformer(nn.Module):
nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
@staticmethod
- def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
+ def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
encoder = OmegaConf.create(encoder)
network_module = importlib.import_module("text_recognizer.networks")
encoder_class = getattr(network_module, encoder.type)