Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 54
1 file changed, 29 insertions(+), 25 deletions(-)
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index 69f68fd..5cdbab4 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,7 +1,9 @@
 """Vector quantized encoder, transformer decoder."""
 from pathlib import Path
-from typing import Tuple, Optional
+from typing import OrderedDict, Tuple, Union
 
+from hydra import compose, initialize
+from hydra.utils import instantiate
 import torch
 from torch import Tensor
 
@@ -21,11 +23,15 @@ class VqTransformer(ConvTransformer):
         dropout_rate: float,
         num_classes: int,
         pad_index: Tensor,
-        encoder: VQVAE,
         decoder: Decoder,
         no_grad: bool,
-        pretrained_encoder_path: Optional[str] = None,
+        pretrained_encoder_path: str,
     ) -> None:
+        # For typing
+        self.encoder: VQVAE = None
+        self.no_grad = no_grad
+
+        self._setup_encoder(pretrained_encoder_path)
         super().__init__(
             input_dims=input_dims,
             encoder_dim=encoder_dim,
@@ -33,32 +39,30 @@ class VqTransformer(ConvTransformer):
             dropout_rate=dropout_rate,
             num_classes=num_classes,
             pad_index=pad_index,
-            encoder=encoder,
+            encoder=self.encoder,
             decoder=decoder,
         )
-        # For typing
-        self.encoder: VQVAE
-
-        self.no_grad = no_grad
-
-        if pretrained_encoder_path is not None:
-            self.pretrained_encoder_path = (
-                Path(__file__).resolve().parents[2] / pretrained_encoder_path
-            )
-            self._setup_encoder()
-        else:
-            self.pretrained_encoder_path = None
-
-    def _load_pretrained_encoder(self) -> None:
-        self.encoder.load_state_dict(
-            torch.load(self.pretrained_encoder_path)["state_dict"]["network"]
-        )
 
-    def _setup_encoder(self) -> None:
-        """Remove unecessary layers."""
-        self._load_pretrained_encoder()
+    def _load_state_dict(self, path: Path) -> OrderedDict:
+        weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0]
+        renamed_state_dict = OrderedDict()
+        state_dict = torch.load(weights_path)["state_dict"]
+        for key in state_dict.keys():
+            if "network" in key:
+                new_key = key.removeprefix("network.")
+                renamed_state_dict[new_key] = state_dict[key]
+        del state_dict
+        return renamed_state_dict
+
+    def _setup_encoder(self, pretrained_encoder_path: str,) -> None:
+        """Load encoder module."""
+        with initialize(config_path=pretrained_encoder_path):
+            cfg = compose(config_name="config")
+        path = Path(__file__).resolve().parents[2] / pretrained_encoder_path
+        state_dict = self._load_state_dict(path)
+        self.encoder = instantiate(cfg.network)
+        self.encoder.load_state_dict(state_dict)
         del self.encoder.decoder
-        # del self.encoder.post_codebook_conv
 
     def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         z_e = self.encoder.encode(x)
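
The new `_setup_encoder` rebuilds the pretrained VQVAE encoder from the Hydra config saved alongside its checkpoint: `initialize`/`compose` load the experiment config, `instantiate(cfg.network)` constructs the network, and `_load_state_dict` strips the Lightning `network.` prefix from the checkpoint keys before loading the weights. The sketch below shows that pattern as a standalone function; the name `load_pretrained_encoder`, the `config_dir` argument, and the `map_location="cpu"` choice are illustrative assumptions, not part of the commit.

```python
from collections import OrderedDict
from pathlib import Path

import torch
from hydra import compose, initialize
from hydra.utils import instantiate


def load_pretrained_encoder(config_dir: str) -> torch.nn.Module:
    """Rebuild a network from a saved Hydra config and load its checkpoint.

    Assumes ``config_dir`` holds a ``config.yaml`` with a ``network`` node and a
    ``checkpoints/`` directory containing a Lightning ``epoch=*.ckpt`` file.
    """
    # Hydra resolves config_path relative to the calling module, as in the patch.
    with initialize(config_path=config_dir):
        cfg = compose(config_name="config")

    # Pick the first Lightning checkpoint found next to the config.
    checkpoint = next((Path(config_dir) / "checkpoints").glob("epoch=*.ckpt"))
    state_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]

    # Lightning stores the wrapped module's weights under a "network." prefix;
    # strip it so the keys match a freshly instantiated network (Python >= 3.9).
    renamed = OrderedDict(
        (key.removeprefix("network."), value)
        for key, value in state_dict.items()
        if key.startswith("network.")
    )

    network = instantiate(cfg.network)
    network.load_state_dict(renamed)
    return network
```

Unlike `_load_state_dict` in the patch, which filters keys with `"network" in key`, this sketch uses `key.startswith("network.")` so that keys merely containing the substring are not picked up by accident.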