In [4]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn.functional as F
import torch
from torch import nn
from torchsummary import summary
from importlib.util import find_spec
# When `text_recognizer` cannot be located on the current import path,
# add the parent directory to `sys.path` so the later project import can
# be resolved from there.
if not find_spec("text_recognizer"):
    import sys

    sys.path += ['..']

In [16]:
from omegaconf import OmegaConf

In [17]:
# Location of the VQ-VAE YAML config; the relative path is resolved against
# the notebook's current working directory.
path = "../training/configs/vqvae.yaml"

In [18]:
# Load the YAML file at `path` into an OmegaConf config object.
conf = OmegaConf.load(path)

In [19]:
# Render the loaded config back to YAML text so its contents can be inspected.
print(OmegaConf.to_yaml(conf))

seed: 4711
network:
 desc: Configuration of the PyTorch neural network.
 type: VQVAE
 args:
 in_channels: 1
 channels:
 - 32
 - 64
 - 96
 - 96
 - 128
 kernel_sizes:
 - 4
 - 4
 - 4
 - 4
 - 4
 strides:
 - 2
 - 2
 - 2
 - 2
 - 2
 num_residual_layers: 2
 embedding_dim: 128
 num_embeddings: 1024
 upsampling: null
 beta: 0.25
 activation: leaky_relu
 dropout_rate: 0.1
model:
 desc: Configuration of the PyTorch Lightning model.
 type: LitVQVAEModel
 args:
 optimizer:
 type: MADGRAD
 args:
 lr: 0.001
 momentum: 0.9
 weight_decay: 0
 eps: 1.0e-06
 lr_scheduler:
 type: OneCycleLR
 args:
 interval: step
 max_lr: 0.001
 three_phase: true
 epochs: 1024
 steps_per_epoch: 317
 criterion:
 type: MSELoss
 args:
 reduction: mean
 monitor: val_loss
 mapping: sentence_piece
data:
 desc: Configuration of the training/test data.
 type: IAMExtendedParagraphs
 args:
 batch_size: 64
 num_workers: 12
 train_fraction: 0.8
 augment: true
callbacks:
- type: ModelCheckpoint
 args:
 monitor: val_loss
 mode: min
 save

In [20]:
from text_recognizer.networks import VQVAE

In [21]:
# Instantiate the network from the `network.args` section of the config;
# every entry there is forwarded to the VQVAE constructor as a keyword argument.
vae = VQVAE(**conf.network.args)

In [22]:
# Display the module repr to inspect the layer structure.
vae

VQVAE(
 (encoder): Encoder(
 (encoder): Sequential(
 (0): Sequential(
 (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.01, inplace=True)
 )
 (1): Dropout(p=0.1, inplace=False)
 (2): Sequential(
 (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.01, inplace=True)
 )
 (3): Dropout(p=0.1, inplace=False)
 (4): Sequential(
 (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.01, inplace=True)
 )
 (5): Dropout(p=0.1, inplace=False)
 (6): Sequential(
 (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.01, inplace=True)
 )
 (7): Dropout(p=0.1, inplace=False)
 (8): Sequential(
 (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.01, inplace=True)
 )
 (9): Dropout(p=0.1, inplace=False)
 (10): _ResidualBlock(
 (block): Sequential(
 (0

In [5]:
# Dummy input: a 2 x 1 x 576 x 640 tensor of standard-normal noise, used
# below only to check output shapes.
datum = torch.randn(2, 1, 576, 640)

In [6]:
# Non-overlapping patch projection: each 16 x 16 window of the single input
# channel is mapped to 32 output channels (kernel size equals stride).
proj = nn.Conv2d(
    in_channels=1,
    out_channels=32,
    kernel_size=(16, 16),
    stride=(16, 16),
)

In [7]:
# Run the dummy batch through the 16 x 16 / stride-16 projection.
x = proj(datum)

In [8]:
# Each spatial dim is divided by the stride: 576 / 16 = 36 and 640 / 16 = 40.
x.shape

torch.Size([2, 32, 36, 40])

In [9]:
# Collapse the two spatial axes (dims 2 and 3) into a single axis.
xx = torch.flatten(x, start_dim=2)

In [10]:
# The merged axis has 36 * 40 = 1440 positions.
xx.shape

torch.Size([2, 32, 1440])

In [11]:
# Swap dims 1 and 2 so the channel axis comes last: [2, 32, 1440] -> [2, 1440, 32].
xxx = torch.transpose(xx, 1, 2)

In [12]:
# Confirm the channels-last layout: [batch, positions, channels].
xxx.shape

torch.Size([2, 1440, 32])

In [13]:
from einops import rearrange

In [14]:
# Same reshaping expressed as one einops pattern: merge the h and w axes and
# move the channel axis last.
xxxx = rearrange(x, "b c h w -> b ( h w ) c")

In [15]:
# Should match the shape produced by the manual flatten + transpose route.
xxxx.shape

torch.Size([2, 1440, 32])

In [None]:
 def split_cls_and_image_tokens(x, size):
     """Split a token sequence into its CLS token and a 2-D image feature map.

     This is the reference snippet wrapped in a function so that the cell only
     defines something: run as bare statements it fails, because the notebook's
     `x` is the 4-D output of `proj` and no `size` is defined in this notebook.

     Args:
         x: Tensor of shape [B, N, C] whose first token along dim 1 is the CLS
             token, followed by H * W image tokens (so N == 1 + H * W).
         size: Pair (H, W) giving the spatial grid of the image tokens.

     Returns:
         Tuple (cls_token, feat) with shapes [B, 1, C] and [B, C, H, W].

     Raises:
         ValueError: If `x` does not have exactly three dimensions
             (the shape unpacking fails).
         AssertionError: If N != 1 + H * W.
     """
     B, N, C = x.shape
     H, W = size
     assert N == 1 + H * W, f"expected {1 + H * W} tokens (1 CLS + H*W), got {N}"

     # Extract CLS token and image tokens.
     cls_token, img_tokens = x[:, :1], x[:, 1:]  # Shape: [B, 1, C], [B, H*W, C].

     # Put channels first and unfold the token axis back onto the H x W grid.
     # The original snippet labelled this step "Depthwise convolution", i.e. the
     # result is presumably the input to such a convolution.
     feat = img_tokens.transpose(1, 2).view(B, C, H, W)
     return cls_token, feat

In [22]:
# Undo the flatten + transpose by hand: swap the axes back and view the
# 1440-long axis as a 36 x 40 grid, recovering the shape of `x`.
xxx.transpose(1, 2).view(2, 32, 36, 40).shape

torch.Size([2, 32, 36, 40])

In [18]:
# Input height divided by 8 — presumably the feature-map height at a total
# stride of 8; TODO confirm what this check was for.
576 / 8

72.0

In [19]:
# Input width divided by 8 — presumably the feature-map width at a total
# stride of 8; TODO confirm what this check was for.
640 / 8

80.0

In [26]:
# Reminder of the dummy input's shape: [batch, channels, height, width].
datum.shape

torch.Size([2, 1, 576, 640])

In [27]:
# Shape of the first element returned by the encoder for the dummy batch.
# Five stride-2 convolutions halve each spatial dim five times:
# 576 / 32 = 18 and 640 / 32 = 20.
vae.encoder(datum)[0].shape

torch.Size([2, 128, 18, 20])

In [87]:
# Shape of the first element returned by the full model; it should equal the
# input shape [2, 1, 576, 640].
vae(datum)[0].shape

torch.Size([2, 1, 576, 640])