From 52a4291e47ca23c9c7a43541f03280ec92aafde3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 30 Sep 2021 23:07:20 +0200 Subject: Bug fix for loading pretrained vq encoder --- text_recognizer/networks/vq_transformer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index 5cdbab4..2121c33 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import OrderedDict, Tuple, Union -from hydra import compose, initialize +from omegaconf import OmegaConf from hydra.utils import instantiate import torch from torch import Tensor @@ -31,7 +31,6 @@ class VqTransformer(ConvTransformer): self.encoder: VQVAE = None self.no_grad = no_grad - self._setup_encoder(pretrained_encoder_path) super().__init__( input_dims=input_dims, encoder_dim=encoder_dim, @@ -42,6 +41,7 @@ class VqTransformer(ConvTransformer): encoder=self.encoder, decoder=decoder, ) + self._setup_encoder(pretrained_encoder_path) def _load_state_dict(self, path: Path) -> OrderedDict: weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0] @@ -56,9 +56,9 @@ class VqTransformer(ConvTransformer): def _setup_encoder(self, pretrained_encoder_path: str,) -> None: """Load encoder module.""" - with initialize(config_path=pretrained_encoder_path): - cfg = compose(config_name="config") path = Path(__file__).resolve().parents[2] / pretrained_encoder_path + with open(path / "config.yaml") as f: + cfg = OmegaConf.load(f) state_dict = self._load_state_dict(path) self.encoder = instantiate(cfg.network) self.encoder.load_state_dict(state_dict) @@ -67,7 +67,8 @@ class VqTransformer(ConvTransformer): def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: z_e = self.encoder.encode(x) z_q, commitment_loss = self.encoder.quantize(z_e) - return z_q, commitment_loss + z = self.encoder.post_codebook_conv(z_q) + return z, commitment_loss def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes an image into a discrete (VQ) latent representation. -- cgit v1.2.3-70-g09d2