In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn.functional as F
import torch
from torch import nn
from torchsummary import summary
from importlib.util import find_spec
if not find_spec("text_recognizer"):
    # text_recognizer is not importable from the current sys.path, so fall
    # back to the parent of the working directory (the notebook's location
    # relative to the package is assumed — confirm if the layout changes).
    import sys

    sys.path.append("..")

In [2]:
from omegaconf import OmegaConf

In [5]:
# Training configs live next to the notebooks directory; paths are relative
# to the kernel's working directory.
CONFIG_DIR = "../training/configs"
path = f"{CONFIG_DIR}/vqvae.yaml"

In [74]:
# Parse the YAML config; OmegaConf.load accepts an open text stream as well
# as a path, and opening it ourselves makes the encoding explicit.
with open(path, encoding="utf-8") as config_file:
    conf = OmegaConf.load(config_file)

In [75]:
# Render the loaded config back to YAML and print it, so real newlines are
# shown instead of an escaped string repr.
config_yaml = OmegaConf.to_yaml(conf)
print(config_yaml)

seed: 4711
network:
  desc: Configuration of the PyTorch neural network.
  type: ImageTransformer
  args:
    in_channels: 1
    channels:
    - 128
    - 64
    - 32
    kernel_sizes:
    - 4
    - 4
    - 4
    strides:
    - 2
    - 2
    - 2
    num_residual_layers: 4
    embedding_dim: 128
    num_embeddings: 1024
    upsampling: null
    beta: 6.6
    activation: leaky_relu
    dropout_rate: 0.25
model:
  desc: Configuration of the PyTorch Lightning model.
  type: LitTransformerModel
  args:
    optimizer:
      type: MADGRAD
      args:
        lr: 0.001
        momentum: 0.9
        weight_decay: 0
        eps: 1.0e-06
    lr_scheduler:
      type: OneCycle
      args:
        interval: step
        max_lr: 0.001
        three_phase: true
        epochs: 512
        steps_per_epoch: 1246
    criterion:
      type: CrossEntropyLoss
      args:
        weight: None
        ignore_index: -100
        reduction: mean
    monitor: val_loss
    mapping: sentence_piece
data:
  desc: C

In [76]:
from text_recognizer.networks import VQVAE

In [78]:
# Build the VQ-VAE from the `network.args` section of the config.
# NOTE(review): the config declares `network.type: ImageTransformer`, yet its
# args are used to construct a VQVAE here — confirm this is the intended config.
network_args = conf.network.args
vae = VQVAE(**network_args)

In [79]:
# Show the module tree (layer types, channel counts, dropout) as a quick check
# that the config values ended up in the architecture as expected.
vae

VQVAE(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
      )
      (1): Dropout(p=0.25, inplace=False)
      (2): Sequential(
        (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
      )
      (3): Dropout(p=0.25, inplace=False)
      (4): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
      )
      (5): Dropout(p=0.25, inplace=False)
      (6): _ResidualBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): Dropout(p=0.25, inplace=False)
    

In [80]:
# Scratch check that torch builds a float tensor; nothing below uses its
# result, so this cell is a candidate for deletion.
torch.tensor([1.0])

tensor([1.])

In [81]:
# Dummy input batch: (batch=2, channels=1, height=576, width=640); the single
# channel matches `in_channels: 1` / the first Conv2d(1, ...) of the encoder.
# A dedicated generator seeded from the config makes the input reproducible
# across Restart & Run All without touching torch's global RNG state.
generator = torch.Generator().manual_seed(conf.seed)
datum = torch.randn([2, 1, 576, 640], generator=generator)

In [82]:
# Confirm the dummy batch has the intended (batch, channels, height, width).
datum.size()

torch.Size([2, 1, 576, 640])

In [85]:
# Shape check only: run the encoder without autograd so no activations are
# retained for a backward pass that never happens.
# The encoder has three stride-2 convolutions, so the spatial size should
# shrink by 8x (576x640 -> 72x80). Element 0 of its output is presumably the
# latent feature map — confirm against the Encoder implementation.
with torch.no_grad():
    encoder_output_shape = vae.encoder(datum)[0].shape
encoder_output_shape

torch.Size([2, 128, 72, 80])

In [87]:
# Full forward pass without autograd (shape check only, no backward needed).
# Element 0 of the VQVAE output is presumably the reconstruction — confirm
# against the VQVAE implementation. It should have the same shape as the input.
with torch.no_grad():
    reconstruction_shape = vae(datum)[0].shape
assert reconstruction_shape == datum.shape, (
    f"reconstruction shape {reconstruction_shape} != input shape {datum.shape}"
)
reconstruction_shape

torch.Size([2, 1, 576, 640])