From 8afa8e1c6e9623b0dea86236da04b2b4173e9443 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 7 Apr 2021 22:12:10 +0200 Subject: Fixed typing and typos, train script load config, reformatted --- text_recognizer/networks/image_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'text_recognizer/networks/image_transformer.py') diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index edebca9..9ed67a4 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -13,7 +13,7 @@ import math from typing import Dict, List, Union, Sequence, Tuple, Type from einops import rearrange -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor @@ -34,7 +34,7 @@ class ImageTransformer(nn.Module): self, input_shape: Sequence[int], output_shape: Sequence[int], - encoder: Union[OmegaConf, Dict], + encoder: Union[DictConfig, Dict], mapping: str, num_decoder_layers: int = 4, hidden_dim: int = 256, @@ -101,7 +101,7 @@ class ImageTransformer(nn.Module): nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) @staticmethod - def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) -- cgit v1.2.3-70-g09d2