summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:06:56 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:06:56 +0200
commitd90e4109f2066f0396cbca31dda34579e3c608d0 (patch)
treed3a6d712462a6946097b08827307ac3c3167794a
parentb25c07af1986a73c2b129bfdcbefbc1dceef1885 (diff)
Add loading of encoder in vq transformer network
-rw-r--r--text_recognizer/networks/vq_transformer.py54
1 files changed, 29 insertions, 25 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index 69f68fd..5cdbab4 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,7 +1,9 @@
"""Vector quantized encoder, transformer decoder."""
from pathlib import Path
-from typing import Tuple, Optional
+from typing import OrderedDict, Tuple, Union
+from hydra import compose, initialize
+from hydra.utils import instantiate
import torch
from torch import Tensor
@@ -21,11 +23,15 @@ class VqTransformer(ConvTransformer):
dropout_rate: float,
num_classes: int,
pad_index: Tensor,
- encoder: VQVAE,
decoder: Decoder,
no_grad: bool,
- pretrained_encoder_path: Optional[str] = None,
+ pretrained_encoder_path: str,
) -> None:
+ # For typing
+ self.encoder: VQVAE = None
+ self.no_grad = no_grad
+
+ self._setup_encoder(pretrained_encoder_path)
super().__init__(
input_dims=input_dims,
encoder_dim=encoder_dim,
@@ -33,32 +39,30 @@ class VqTransformer(ConvTransformer):
dropout_rate=dropout_rate,
num_classes=num_classes,
pad_index=pad_index,
- encoder=encoder,
+ encoder=self.encoder,
decoder=decoder,
)
- # For typing
- self.encoder: VQVAE
-
- self.no_grad = no_grad
-
- if pretrained_encoder_path is not None:
- self.pretrained_encoder_path = (
- Path(__file__).resolve().parents[2] / pretrained_encoder_path
- )
- self._setup_encoder()
- else:
- self.pretrained_encoder_path = None
-
- def _load_pretrained_encoder(self) -> None:
- self.encoder.load_state_dict(
- torch.load(self.pretrained_encoder_path)["state_dict"]["network"]
- )
- def _setup_encoder(self) -> None:
- """Remove unecessary layers."""
- self._load_pretrained_encoder()
+ def _load_state_dict(self, path: Path) -> OrderedDict:
+ weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0]
+ renamed_state_dict = OrderedDict()
+ state_dict = torch.load(weights_path)["state_dict"]
+ for key in state_dict.keys():
+ if "network" in key:
+ new_key = key.removeprefix("network.")
+ renamed_state_dict[new_key] = state_dict[key]
+ del state_dict
+ return renamed_state_dict
+
+ def _setup_encoder(self, pretrained_encoder_path: str,) -> None:
+ """Load encoder module."""
+ with initialize(config_path=pretrained_encoder_path):
+ cfg = compose(config_name="config")
+ path = Path(__file__).resolve().parents[2] / pretrained_encoder_path
+ state_dict = self._load_state_dict(path)
+ self.encoder = instantiate(cfg.network)
+ self.encoder.load_state_dict(state_dict)
del self.encoder.decoder
- # del self.encoder.post_codebook_conv
def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
z_e = self.encoder.encode(x)