diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:07:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:07:20 +0200 |
commit | 52a4291e47ca23c9c7a43541f03280ec92aafde3 (patch) | |
tree | 025605bdbd0d94b17634581f7bb9a9a2d14facad /text_recognizer/networks | |
parent | 58bff6b69287abffc8df481ba6fb5fec9a072054 (diff) |
Bug fix for loading pretrained vq encoder
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index 5cdbab4..2121c33 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import OrderedDict, Tuple, Union -from hydra import compose, initialize +from omegaconf import OmegaConf from hydra.utils import instantiate import torch from torch import Tensor @@ -31,7 +31,6 @@ class VqTransformer(ConvTransformer): self.encoder: VQVAE = None self.no_grad = no_grad - self._setup_encoder(pretrained_encoder_path) super().__init__( input_dims=input_dims, encoder_dim=encoder_dim, @@ -42,6 +41,7 @@ class VqTransformer(ConvTransformer): encoder=self.encoder, decoder=decoder, ) + self._setup_encoder(pretrained_encoder_path) def _load_state_dict(self, path: Path) -> OrderedDict: weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0] @@ -56,9 +56,9 @@ class VqTransformer(ConvTransformer): def _setup_encoder(self, pretrained_encoder_path: str,) -> None: """Load encoder module.""" - with initialize(config_path=pretrained_encoder_path): - cfg = compose(config_name="config") path = Path(__file__).resolve().parents[2] / pretrained_encoder_path + with open(path / "config.yaml") as f: + cfg = OmegaConf.load(f) state_dict = self._load_state_dict(path) self.encoder = instantiate(cfg.network) self.encoder.load_state_dict(state_dict) @@ -67,7 +67,8 @@ class VqTransformer(ConvTransformer): def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: z_e = self.encoder.encode(x) z_q, commitment_loss = self.encoder.quantize(z_e) - return z_q, commitment_loss + z = self.encoder.post_codebook_conv(z_q) + return z, commitment_loss def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes an image into a discrete (VQ) latent representation. |