In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn.functional as F
import torch
from torch import nn
from torchsummary import summary
from importlib.util import find_spec
if find_spec("text_recognizer") is None:
    import sys
    sys.path.append('..')

from text_recognizer.networks.transformer.vit import ViT
from text_recognizer.networks.transformer.transformer import Transformer
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.nystromer.nystromer import Nystromer

In [2]:
torch.cuda.is_available()

True

In [3]:
decoder = Decoder(dim=64, depth=2, num_heads=4, ff_kwargs={}, attn_kwargs={}, cross_attend=True)

In [4]:
decoder.cuda()

Decoder(
  (layers): ModuleList(
    (0): ModuleList(
      (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (1): Attention(
        (qkv_fn): Sequential(
          (0): Linear(in_features=64, out_features=12288, bias=False)
          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (fc): Linear(in_features=4096, out_features=64, bias=True)
      )
      (2): Residual()
    )
    (1): ModuleList(
      (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (1): Attention(
        (qkv_fn): Sequential(
          (0): Linear(in_features=64, out_features=12288, bias=False)
          (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (fc): Linear(in_features=4096, out_features=64, bias=True)
      )
      (2): Residual()
    )
    (2): ModuleList(
      (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      

In [5]:
transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=64, emb_dropout=0.1)

In [6]:
transformer_decoder.cuda()

Transformer(
  (attn_layers): Decoder(
    (layers): ModuleList(
      (0): ModuleList(
        (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (1): Attention(
          (qkv_fn): Sequential(
            (0): Linear(in_features=64, out_features=12288, bias=False)
            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (fc): Linear(in_features=4096, out_features=64, bias=True)
        )
        (2): Residual()
      )
      (1): ModuleList(
        (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (1): Attention(
          (qkv_fn): Sequential(
            (0): Linear(in_features=64, out_features=12288, bias=False)
            (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (fc): Linear(in_features=4096, out_features=64, bias=True)
        )
        (2): Residual()
      )
      (2): 

In [7]:
efficient_transformer = Nystromer(
    dim = 64,
    depth = 4,
    num_heads = 8,
    num_landmarks = 64
)

In [8]:
v = ViT(
    dim = 64,
    image_size = (576, 640),
    patch_size = (64, 64),
    transformer = efficient_transformer
).cuda()

In [9]:
t = torch.randn(4, 1, 576, 640).cuda()

In [10]:
o = v(t)

In [11]:
caption = torch.randint(0, 90, (16, 690)).cuda()

In [12]:
o.shape

torch.Size([4, 90, 64])

In [13]:
caption.shape

torch.Size([16, 690])

In [14]:
transformer_decoder(caption, context = o).shape # (1, 1024, 20000)

torch.Size([16, 690, 90])

In [None]:
from text_recognizer.networks.encoders.efficientnet import EfficientNet

In [None]:
en = EfficientNet()

In [None]:
en.cuda()

In [None]:
summary(en, (1, 576, 640))

In [None]:
type(efficient_transformer)

In [None]:
efficient_transformer = efficient_transformer(num_landmarks=256)

In [None]:
efficient_transformer()

In [None]:
from omegaconf import OmegaConf

In [None]:
path = "../training/configs/vqvae.yaml"

In [None]:
conf = OmegaConf.load(path)

In [None]:
print(OmegaConf.to_yaml(conf))

In [None]:
from text_recognizer.networks import VQVAE

In [None]:
vae = VQVAE(**conf.network.args)

In [None]:
vae

In [None]:
datum = torch.randn([2, 1, 576, 640])

In [None]:
vae.encoder(datum)[0].shape

In [None]:
vae(datum)[0].shape

In [None]:
datum = torch.randn([2, 1, 576, 640])

In [None]:
trg = torch.randint(0, 1000, [2, 682])

In [None]:
trg.shape

In [None]:
datum = torch.randn([2, 1, 224, 224])

In [None]:
en(t).shape

In [None]:
path = "../training/configs/cnn_transformer.yaml"

In [None]:
conf = OmegaConf.load(path)

In [None]:
print(OmegaConf.to_yaml(conf))

In [None]:
from text_recognizer.networks.cnn_transformer import CNNTransformer

In [None]:
t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)

In [None]:
t.encode(datum).shape

In [None]:
trg.shape

In [None]:
t(datum, trg).shape

In [None]:
b, n = 16, 128
device = "cpu"

In [None]:
x = lambda: torch.ones((b, n), device=device).bool()

In [None]:
x().shape

In [None]:
torch.ones((b, n), device=device).bool().shape

In [None]:
x = torch.randn(1, 1, 576, 640)

In [None]:
576 // 32

In [None]:
640 // 32

In [None]:
18 * 20

In [None]:
x = torch.randn(1, 1, 144, 160)

In [None]:
from einops import rearrange

In [None]:
patch_size=16
p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)

In [None]:
p.shape