author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-09-30 23:07:20 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-09-30 23:07:20 +0200
commit    52a4291e47ca23c9c7a43541f03280ec92aafde3 (patch)
tree      025605bdbd0d94b17634581f7bb9a9a2d14facad /text_recognizer/networks
parent    58bff6b69287abffc8df481ba6fb5fec9a072054 (diff)
Bug fix for loading pretrained vq encoder
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--    text_recognizer/networks/vq_transformer.py    11
1 file changed, 6 insertions(+), 5 deletions(-)
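
The change replaces Hydra's initialize()/compose() (whose config_path is resolved relative to the calling module) with a plain OmegaConf.load of the config.yaml stored next to the pretrained encoder's checkpoints, defers _setup_encoder until after super().__init__() has run, and passes the quantized codes through the encoder's post-codebook convolution. Below is a minimal standalone sketch of the corrected loading path; the directory layout and the Lightning-style "state_dict" checkpoint key are taken from this diff or assumed, and any key-prefix cleanup done by _load_state_dict is omitted.

# Sketch: load a pretrained VQ-VAE encoder from its saved config + checkpoint.
# Not the repository's actual code; see the diff below for the real change.
from pathlib import Path

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf


def load_pretrained_vq_encoder(pretrained_encoder_path: str) -> torch.nn.Module:
    # Resolve the encoder directory relative to the repository root instead of
    # relying on hydra.initialize(), whose config_path is caller-relative.
    path = Path(__file__).resolve().parents[2] / pretrained_encoder_path

    # Load the training config that was written next to the checkpoints.
    with open(path / "config.yaml") as f:
        cfg = OmegaConf.load(f)

    # Take the first saved checkpoint and pull out the model weights
    # (assumed Lightning-style checkpoint with a top-level "state_dict" entry;
    # any key-prefix stripping done by _load_state_dict is omitted here).
    ckpt = list((path / "checkpoints").glob("epoch=*.ckpt"))[0]
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]

    # Build the network described by the config and restore its weights.
    encoder = instantiate(cfg.network)
    encoder.load_state_dict(state_dict)
    return encoder
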
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index 5cdbab4..2121c33 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import OrderedDict, Tuple, Union
-from hydra import compose, initialize
+from omegaconf import OmegaConf
from hydra.utils import instantiate
import torch
from torch import Tensor
@@ -31,7 +31,6 @@ class VqTransformer(ConvTransformer):
self.encoder: VQVAE = None
self.no_grad = no_grad
- self._setup_encoder(pretrained_encoder_path)
super().__init__(
input_dims=input_dims,
encoder_dim=encoder_dim,
@@ -42,6 +41,7 @@ class VqTransformer(ConvTransformer):
encoder=self.encoder,
decoder=decoder,
)
+ self._setup_encoder(pretrained_encoder_path)
def _load_state_dict(self, path: Path) -> OrderedDict:
weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0]
@@ -56,9 +56,9 @@ class VqTransformer(ConvTransformer):
def _setup_encoder(self, pretrained_encoder_path: str,) -> None:
"""Load encoder module."""
- with initialize(config_path=pretrained_encoder_path):
- cfg = compose(config_name="config")
path = Path(__file__).resolve().parents[2] / pretrained_encoder_path
+ with open(path / "config.yaml") as f:
+ cfg = OmegaConf.load(f)
state_dict = self._load_state_dict(path)
self.encoder = instantiate(cfg.network)
self.encoder.load_state_dict(state_dict)
@@ -67,7 +67,8 @@ class VqTransformer(ConvTransformer):
def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
z_e = self.encoder.encode(x)
z_q, commitment_loss = self.encoder.quantize(z_e)
- return z_q, commitment_loss
+ z = self.encoder.post_codebook_conv(z_q)
+ return z, commitment_loss
def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes an image into a discrete (VQ) latent representation.