path: root/text_recognizer/networks
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-06 02:42:45 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-06 02:42:45 +0200
commit    3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch)
tree      136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks
parent    1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff)
Refactor VQVAE into separate encoder/decoder modules; fix bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/conv_transformer.py    7
-rw-r--r--  text_recognizer/networks/vq_transformer.py      50
-rw-r--r--  text_recognizer/networks/vqvae/__init__.py       1
-rw-r--r--  text_recognizer/networks/vqvae/attention.py      7
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py       83
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py       82
-rw-r--r--  text_recognizer/networks/vqvae/norm.py           4
-rw-r--r--  text_recognizer/networks/vqvae/pixelcnn.py     165
-rw-r--r--  text_recognizer/networks/vqvae/quantizer.py     15
-rw-r--r--  text_recognizer/networks/vqvae/residual.py      53
-rw-r--r--  text_recognizer/networks/vqvae/resize.py         2
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py         98
12 files changed, 359 insertions(+), 208 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index f3ba49d..b1a101e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -4,7 +4,6 @@ from typing import Tuple
from torch import nn, Tensor
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
PositionalEncoding,
@@ -18,15 +17,17 @@ class ConvTransformer(nn.Module):
def __init__(
self,
input_dims: Tuple[int, int, int],
+ encoder_dim: int,
hidden_dim: int,
dropout_rate: float,
num_classes: int,
pad_index: Tensor,
- encoder: EfficientNet,
+ encoder: nn.Module,
decoder: Decoder,
) -> None:
super().__init__()
self.input_dims = input_dims
+ self.encoder_dim = encoder_dim
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
self.num_classes = num_classes
@@ -38,7 +39,7 @@ class ConvTransformer(nn.Module):
# positional encoding.
self.latent_encoder = nn.Sequential(
nn.Conv2d(
- in_channels=self.encoder.out_channels,
+ in_channels=self.encoder_dim,
out_channels=self.hidden_dim,
kernel_size=1,
),
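
The encoder is now a generic nn.Module, so ConvTransformer can no longer read the channel
count off an EfficientNet instance; it must be passed explicitly as encoder_dim. A minimal
sketch of the resulting latent projector, with illustrative dimensions and the repo's
PositionalEncoding2D omitted:

    import torch
    from torch import nn

    encoder_dim, hidden_dim = 256, 128  # hypothetical values
    latent_encoder = nn.Sequential(
        nn.Conv2d(in_channels=encoder_dim, out_channels=hidden_dim, kernel_size=1),
        nn.Flatten(start_dim=2),             # [B, hidden_dim, Ho * Wo]
    )
    z = torch.randn(2, encoder_dim, 18, 20)   # stand-in for an encoder feature map
    out = latent_encoder(z).permute(0, 2, 1)  # [B, Ho * Wo, hidden_dim]
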
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index a972565..0433863 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,16 +1,12 @@
"""Vector quantized encoder, transformer decoder."""
-import math
from typing import Tuple
-from torch import nn, Tensor
+import torch
+from torch import Tensor
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
+from text_recognizer.networks.vqvae.vqvae import VQVAE
from text_recognizer.networks.conv_transformer import ConvTransformer
from text_recognizer.networks.transformer.layers import Decoder
-from text_recognizer.networks.transformer.positional_encodings import (
- PositionalEncoding,
- PositionalEncoding2D,
-)
class VqTransformer(ConvTransformer):
@@ -19,16 +15,18 @@ class VqTransformer(ConvTransformer):
def __init__(
self,
input_dims: Tuple[int, int, int],
+ encoder_dim: int,
hidden_dim: int,
dropout_rate: float,
num_classes: int,
pad_index: Tensor,
- encoder: EfficientNet,
+ encoder: VQVAE,
decoder: Decoder,
+ pretrained_encoder_path: str,
) -> None:
- # TODO: Load pretrained vqvae encoder.
super().__init__(
input_dims=input_dims,
+ encoder_dim=encoder_dim,
hidden_dim=hidden_dim,
dropout_rate=dropout_rate,
num_classes=num_classes,
@@ -36,24 +34,19 @@ class VqTransformer(ConvTransformer):
encoder=encoder,
decoder=decoder,
)
- # Latent projector for down sampling number of filters and 2d
- # positional encoding.
- self.latent_encoder = nn.Sequential(
- nn.Conv2d(
- in_channels=self.encoder.out_channels,
- out_channels=self.hidden_dim,
- kernel_size=1,
- ),
- PositionalEncoding2D(
- hidden_dim=self.hidden_dim,
- max_h=self.input_dims[1],
- max_w=self.input_dims[2],
- ),
- nn.Flatten(start_dim=2),
- )
+ self.pretrained_encoder_path = pretrained_encoder_path
+
+ # For typing
+ self.encoder: VQVAE
+
+ def setup_encoder(self) -> None:
+ """Remove unnecessary layers."""
+ # TODO: load pretrained vqvae
+ del self.encoder.decoder
+ del self.encoder.post_codebook_conv
def encode(self, x: Tensor) -> Tensor:
- """Encodes an image into a latent feature vector.
+ """Encodes an image into a discrete (VQ) latent representation.
Args:
x (Tensor): Image tensor.
@@ -69,8 +62,11 @@ class VqTransformer(ConvTransformer):
Returns:
Tensor: A latent embedding of the image.
"""
- z = self.encoder(x)
- z = self.latent_encoder(z)
+ with torch.no_grad():
+ z_e = self.encoder.encode(x)
+ z_q, _ = self.encoder.quantize(z_e)
+
+ z = self.latent_encoder(z_q)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
z = z.permute(0, 2, 1)
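
The new encode path runs the pretrained VQVAE encoder and quantizer under torch.no_grad(),
so only the latent projector and the transformer decoder receive gradients. A sketch of the
flow using a hypothetical stand-in for the VQVAE interface (the encode/quantize names follow
the diff; DummyVQVAE is not the repo's class):

    import torch
    from torch import nn

    class DummyVQVAE(nn.Module):
        """Hypothetical stand-in exposing the encode/quantize interface."""

        def __init__(self) -> None:
            super().__init__()
            self.enc = nn.Conv2d(1, 8, kernel_size=4, stride=2, padding=1)
            self.codebook = nn.Parameter(torch.randn(16, 8))  # 16 codes of dim 8

        def encode(self, x):
            return self.enc(x)

        def quantize(self, z_e):
            B, C, H, W = z_e.shape
            flat = z_e.permute(0, 2, 3, 1).reshape(-1, C)
            idx = torch.cdist(flat, self.codebook).argmin(dim=1)
            z_q = self.codebook[idx].reshape(B, H, W, C).permute(0, 3, 1, 2)
            return z_q, torch.tensor(0.0)

    vqvae = DummyVQVAE()
    x = torch.randn(2, 1, 32, 32)
    with torch.no_grad():  # the pretrained encoder stays frozen
        z_q, _ = vqvae.quantize(vqvae.encode(x))
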
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
index 7d56bdb..e1f05fa 100644
--- a/text_recognizer/networks/vqvae/__init__.py
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -1,2 +1 @@
"""VQ-VAE module."""
-from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
index 5a6b3ce..78a2cc9 100644
--- a/text_recognizer/networks/vqvae/attention.py
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from text_recognizer.networks.vqvae.norm import Normalize
-@attr.s
+@attr.s(eq=False)
class Attention(nn.Module):
"""Convolutional attention."""
@@ -63,11 +63,12 @@ class Attention(nn.Module):
B, C, H, W = q.shape
q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C]
k = k.reshape(B, C, H * W) # [B, C, HW]
- energy = torch.bmm(q, k) * (C ** -0.5)
+ energy = torch.bmm(q, k) * (int(C) ** -0.5)
attention = F.softmax(energy, dim=2)
# Compute attention to which values
- v = v.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C]
+ v = v.reshape(B, C, H * W)
+ attention = attention.permute(0, 2, 1) # [B, HW, HW]
out = torch.bmm(v, attention)
out = out.reshape(B, C, H, W)
out = self.proj(out)
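
This hunk fixes a real bug: v was previously permuted to [B, HW, C], so torch.bmm(v, attention)
only type-checked when C happened to equal HW. Keeping v as [B, C, HW] and transposing the
attention map gives the intended weighting over value positions, as the shapes show:

    import torch
    import torch.nn.functional as F

    B, C, H, W = 1, 8, 4, 4
    q = torch.randn(B, C, H * W).permute(0, 2, 1)   # [B, HW, C]
    k = torch.randn(B, C, H * W)                    # [B, C, HW]
    energy = torch.bmm(q, k) * (C ** -0.5)          # [B, HW, HW]
    attention = F.softmax(energy, dim=2)
    v = torch.randn(B, C, H * W)                    # [B, C, HW]
    out = torch.bmm(v, attention.permute(0, 2, 1))  # [B, C, HW]
    out = out.reshape(B, C, H, W)
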
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index fcf768b..f51e0a3 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,62 +1,69 @@
"""CNN decoder for the VQ-VAE."""
-import attr
+from typing import Sequence
+
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.norm import Normalize
from text_recognizer.networks.vqvae.residual import Residual
-@attr.s(eq=False)
class Decoder(nn.Module):
"""A CNN decoder network."""
- in_channels: int = attr.ib()
- embedding_dim: int = attr.ib()
- out_channels: int = attr.ib()
- res_channels: int = attr.ib()
- num_residual_layers: int = attr.ib()
- activation: str = attr.ib()
- decoder: nn.Sequential = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+ def __init__(
+ self,
+ out_channels: int,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ dropout_rate: float,
+ activation: str = "mish",
+ ) -> None:
super().__init__()
+ self.out_channels = out_channels
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.activation = activation
+ self.dropout_rate = dropout_rate
self.decoder = self._build_decompression_block()
def _build_decompression_block(self,) -> nn.Sequential:
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
+ decoder = []
+ for _ in range(2):
+ decoder += [
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+
activation_fn = activation_function(self.activation)
- blocks = [
+ out_channels_multipliers = self.channels_multipliers + (1, )
+ num_blocks = len(self.channels_multipliers)
+
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * self.channels_multipliers[i]
+ out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ decoder += [
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ ]
+
+ decoder += [
+ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+ nn.Mish(),
nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.embedding_dim,
- kernel_size=3,
- padding=1,
- )
- ]
- for _ in range(self.num_residual_layers):
- blocks.append(
- Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
- )
- blocks.append(activation_fn)
- blocks += [
- nn.ConvTranspose2d(
- in_channels=self.embedding_dim,
- out_channels=self.embedding_dim // 2,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.ConvTranspose2d(
- in_channels=self.embedding_dim // 2,
+ in_channels=self.hidden_dim * out_channels_multipliers[-1],
out_channels=self.out_channels,
- kernel_size=4,
- stride=2,
+ kernel_size=3,
+ stride=1,
padding=1,
),
]
- return nn.Sequential(*blocks)
+ return nn.Sequential(*decoder)
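
Each entry in channels_multipliers now adds one ConvTranspose2d stage to the block built above,
so the decoder upsamples by 2 ** len(channels_multipliers) overall. A shape sketch of the new
interface, assuming forward simply applies self.decoder and using illustrative sizes:

    import torch
    from text_recognizer.networks.vqvae.decoder import Decoder

    decoder = Decoder(
        out_channels=1, hidden_dim=32, channels_multipliers=[4, 2, 1], dropout_rate=0.1
    )
    z_q = torch.randn(1, 32 * 4, 9, 10)  # in_channels = hidden_dim * multipliers[0]
    x_hat = decoder(z_q)                 # three stride-2 stages: 9x10 -> 72x80
    # expected: x_hat.shape == torch.Size([1, 1, 72, 80])
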
def forward(self, z_q: Tensor) -> Tensor:
"""Reconstruct input from given codes."""
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index f086c6b..ad8f950 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,7 +1,6 @@
"""CNN encoder for the VQ-VAE."""
-from typing import Sequence, Optional, Tuple, Type
+from typing import List, Tuple
-import attr
from torch import nn
from torch import Tensor
@@ -9,64 +8,59 @@ from text_recognizer.networks.util import activation_function
from text_recognizer.networks.vqvae.residual import Residual
-@attr.s(eq=False)
class Encoder(nn.Module):
"""A CNN encoder network."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
- res_channels: int = attr.ib()
- num_residual_layers: int = attr.ib()
- embedding_dim: int = attr.ib()
- activation: str = attr.ib()
- encoder: nn.Sequential = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_dim: int,
+ channels_multipliers: List[int],
+ dropout_rate: float,
+ activation: str = "mish",
+ ) -> None:
super().__init__()
+ self.in_channels = in_channels
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.activation = activation
+ self.dropout_rate = dropout_rate
self.encoder = self._build_compression_block()
def _build_compression_block(self) -> nn.Sequential:
- activation_fn = activation_function(self.activation)
- block = [
+ """Builds encoder network."""
+ encoder = [
nn.Conv2d(
in_channels=self.in_channels,
- out_channels=self.out_channels // 2,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.Conv2d(
- in_channels=self.out_channels // 2,
- out_channels=self.out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.Conv2d(
- in_channels=self.out_channels,
- out_channels=self.out_channels,
+ out_channels=self.hidden_dim,
kernel_size=3,
+ stride=1,
padding=1,
),
]
- for _ in range(self.num_residual_layers):
- block.append(
- Residual(in_channels=self.out_channels, out_channels=self.res_channels)
- )
+ num_blocks = len(self.channels_multipliers)
+ channels_multipliers = (1, ) + self.channels_multipliers
+ activation_fn = activation_function(self.activation)
- block.append(
- nn.Conv2d(
- in_channels=self.out_channels,
- out_channels=self.embedding_dim,
- kernel_size=1,
- )
- )
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * channels_multipliers[i]
+ out_channels = self.hidden_dim * channels_multipliers[i + 1]
+ encoder += [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ ]
+
+ for _ in range(2):
+ encoder += [
+ Residual(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ )
+ ]
- return nn.Sequential(*block)
+ return nn.Sequential(*encoder)
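
Symmetrically, each multiplier adds one stride-2 convolution, so the encoder downsamples by
2 ** len(channels_multipliers) and ends at hidden_dim * channels_multipliers[-1] channels.
A shape sketch, assuming forward returns the encoded feature map:

    import torch
    from text_recognizer.networks.vqvae.encoder import Encoder

    encoder = Encoder(
        in_channels=1, hidden_dim=32, channels_multipliers=[1, 2, 4], dropout_rate=0.1
    )
    x = torch.randn(1, 1, 72, 80)
    z = encoder(x)  # three stride-2 convs: 72x80 -> 9x10
    # expected: z.shape == torch.Size([1, 128, 9, 10])
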
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes input into a discrete representation."""
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
index df66efc..3e6963a 100644
--- a/text_recognizer/networks/vqvae/norm.py
+++ b/text_recognizer/networks/vqvae/norm.py
@@ -3,7 +3,7 @@ import attr
from torch import nn, Tensor
-@attr.s
+@attr.s(eq=False)
class Normalize(nn.Module):
num_channels: int = attr.ib()
norm: nn.GroupNorm = attr.ib(init=False)
@@ -12,7 +12,7 @@ class Normalize(nn.Module):
"""Post init configuration."""
super().__init__()
self.norm = nn.GroupNorm(
- num_groups=32, num_channels=self.num_channels, eps=1.0e-6, affine=True
+ num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True
)
def forward(self, x: Tensor) -> Tensor:
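
Setting num_groups equal to num_channels normalizes each channel independently, i.e. GroupNorm
degenerates to an affine InstanceNorm; it also drops the old implicit requirement that the
channel count be divisible by 32. A small demonstration:

    import torch
    from torch import nn

    norm = nn.GroupNorm(num_groups=8, num_channels=8, eps=1.0e-6, affine=True)
    x = torch.randn(2, 8, 4, 4)
    y = norm(x)  # each channel has mean ~0 and std ~1 before the affine map
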
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
new file mode 100644
index 0000000..5c580df
--- /dev/null
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -0,0 +1,165 @@
+"""PixelCNN encoder and decoder.
+
+Same as in VQGAN, among others. Hopefully, this yields better reconstructions.
+
+TODO: Add the number of residual layers as a parameter.
+"""
+from typing import Sequence
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae.attention import Attention
+from text_recognizer.networks.vqvae.norm import Normalize
+from text_recognizer.networks.vqvae.residual import Residual
+from text_recognizer.networks.vqvae.resize import Downsample, Upsample
+
+
+class Encoder(nn.Module):
+ """PixelCNN encoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.dropout_rate = dropout_rate
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.encoder = self._build_encoder()
+
+ def _build_encoder(self) -> nn.Sequential:
+ """Builds encoder."""
+ encoder = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ num_blocks = len(self.channels_multipliers)
+ in_channels_multipliers = (1,) + self.channels_multipliers
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * in_channels_multipliers[i]
+ out_channels = self.hidden_dim * self.channels_multipliers[i]
+ encoder += [
+ Residual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+ if i == num_blocks - 1:
+ encoder.append(Attention(in_channels=out_channels))
+ encoder.append(Downsample())
+
+ for _ in range(2):
+ encoder += [
+ Residual(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1])
+ ]
+
+ encoder += [
+ Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]),
+ nn.Mish(),
+ nn.Conv2d(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes input to a latent representation."""
+ return self.encoder(x)
+
+
+class Decoder(nn.Module):
+ """PixelCNN decoder."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ out_channels: int,
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.out_channels = out_channels
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.dropout_rate = dropout_rate
+ self.decoder = self._build_decoder()
+
+ def _build_decoder(self) -> nn.Sequential:
+ """Builds decoder."""
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
+ decoder = [
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ Attention(in_channels=in_channels),
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+
+ out_channels_multipliers = self.channels_multipliers + (1, )
+ num_blocks = len(self.channels_multipliers)
+
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * self.channels_multipliers[i]
+ out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ decoder.append(
+ Residual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ )
+ )
+ if i == 0:
+ decoder.append(
+ Attention(
+ in_channels=out_channels
+ )
+ )
+ decoder.append(Upsample())
+
+ decoder += [
+ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+ nn.Mish(),
+ nn.Conv2d(
+ in_channels=self.hidden_dim * out_channels_multipliers[-1],
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*decoder)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Decodes latent vector."""
+ return self.decoder(x)
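
The two networks are meant to mirror each other: the encoder applies one Downsample per
multiplier, the decoder one Upsample, so the decoder should receive the encoder's multipliers
in reverse order. A roundtrip sketch with illustrative sizes, assuming Downsample and Upsample
halve and double the spatial dimensions once per multiplier:

    import torch
    from text_recognizer.networks.vqvae.pixelcnn import Decoder, Encoder

    encoder = Encoder(
        in_channels=1, hidden_dim=32, channels_multipliers=[1, 2], dropout_rate=0.1
    )
    decoder = Decoder(
        hidden_dim=32, channels_multipliers=[2, 1], out_channels=1, dropout_rate=0.1
    )
    x = torch.randn(1, 1, 32, 32)
    z = encoder(x)      # expected [1, 64, 8, 8] after two Downsample steps
    x_hat = decoder(z)  # expected [1, 1, 32, 32] after two Upsample steps
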
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index a4f11f0..6fb57e8 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -11,13 +11,15 @@ import torch.nn.functional as F
class EmbeddingEMA(nn.Module):
+ """Embedding for Exponential Moving Average (EMA)."""
+
def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
super().__init__()
weight = torch.zeros(num_embeddings, embedding_dim)
nn.init.kaiming_uniform_(weight, nonlinearity="linear")
self.register_buffer("weight", weight)
- self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
- self.register_buffer("_weight_avg", weight)
+ self.register_buffer("cluster_size", torch.zeros(num_embeddings))
+ self.register_buffer("weight_avg", weight.clone())
class VectorQuantizer(nn.Module):
@@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module):
return quantized_latent
def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+ """Computes the EMA update to the codebook."""
batch_cluster_size = one_hot_encoding.sum(axis=0)
batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
- self.embedding._cluster_size.data.mul_(self.decay).add_(
+ self.embedding.cluster_size.data.mul_(self.decay).add_(
batch_cluster_size, alpha=1 - self.decay
)
- self.embedding._weight_avg.data.mul_(self.decay).add_(
+ self.embedding.weight_avg.data.mul_(self.decay).add_(
batch_embedding_avg, alpha=1 - self.decay
)
- new_embedding = self.embedding._weight_avg / (
- self.embedding._cluster_size + 1.0e-5
+ new_embedding = self.embedding.weight_avg / (
+ self.embedding.cluster_size + 1.0e-5
).unsqueeze(1)
self.embedding.weight.data.copy_(new_embedding)
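
The rename here doubles as a bug fix: register_buffer("weight", weight) and
register_buffer("_weight_avg", weight) registered the same tensor, so in-place EMA updates to
the running average silently mutated the live codebook; weight.clone() decouples the two. The
update itself, isolated with illustrative sizes:

    import torch
    import torch.nn.functional as F

    decay, num_embeddings, embedding_dim, n = 0.99, 4, 3, 10
    one_hot = F.one_hot(torch.randint(0, num_embeddings, (n,)), num_embeddings).float()
    latent = torch.randn(n, embedding_dim)

    cluster_size = torch.zeros(num_embeddings)
    weight_avg = torch.randn(num_embeddings, embedding_dim)

    batch_cluster_size = one_hot.sum(dim=0)           # per-code usage counts
    batch_embedding_avg = (latent.t() @ one_hot).t()  # per-code latent sums
    cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
    weight_avg.mul_(decay).add_(batch_embedding_avg, alpha=1 - decay)
    new_weight = weight_avg / (cluster_size + 1.0e-5).unsqueeze(1)
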
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 98109b8..4ed3781 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -1,18 +1,55 @@
"""Residual block."""
+import attr
from torch import nn
from torch import Tensor
+from text_recognizer.networks.vqvae.norm import Normalize
+
+@attr.s(eq=False)
class Residual(nn.Module):
- def __init__(self, in_channels: int, out_channels: int,) -> None:
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ dropout_rate: float = attr.ib(default=0.0)
+ use_norm: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
super().__init__()
- self.block = nn.Sequential(
- nn.Mish(inplace=True),
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.Mish(inplace=True),
- nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
- )
+ self.block = self._build_res_block()
+ if self.in_channels != self.out_channels:
+ self.conv_shortcut = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ else:
+ self.conv_shortcut = None
+
+ def _build_res_block(self) -> nn.Sequential:
+ """Build residual block."""
+ block = []
+ if self.use_norm:
+ block.append(Normalize(num_channels=self.in_channels))
+ block += [
+ nn.Mish(),
+ nn.Conv2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if self.dropout_rate:
+ block += [nn.Dropout(p=self.dropout_rate)]
+
+ if self.use_norm:
+ block.append(Normalize(num_channels=self.out_channels))
+
+ block += [
+ nn.Mish(),
+ nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
+ ]
+ return nn.Sequential(*block)
def forward(self, x: Tensor) -> Tensor:
"""Apply the residual forward pass."""
- return x + self.block(x)
+ residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
+ return residual + self.block(x)
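
With the projection shortcut, the residual add no longer requires in_channels == out_channels;
the 3x3 shortcut convolution matches channel counts while preserving spatial size. A quick
check, assuming the repo is importable:

    import torch
    from text_recognizer.networks.vqvae.residual import Residual

    block = Residual(in_channels=32, out_channels=64, dropout_rate=0.1, use_norm=True)
    x = torch.randn(1, 32, 8, 8)
    y = block(x)  # projected shortcut + residual branch
    # expected: y.shape == torch.Size([1, 64, 8, 8])
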
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
index 769d089..8d67d02 100644
--- a/text_recognizer/networks/vqvae/resize.py
+++ b/text_recognizer/networks/vqvae/resize.py
@@ -8,7 +8,7 @@ class Upsample(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Applies upsampling."""
- return F.interpolate(x, scale_factor=2, mode="nearest")
+ return F.interpolate(x, scale_factor=2.0, mode="nearest")
class Downsample(nn.Module):
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1585d40..0646119 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,13 +1,9 @@
"""The VQ-VAE."""
from typing import Tuple
-import torch
from torch import nn
from torch import Tensor
-import torch.nn.functional as F
-from text_recognizer.networks.vqvae.decoder import Decoder
-from text_recognizer.networks.vqvae.encoder import Encoder
from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
@@ -16,93 +12,45 @@ class VQVAE(nn.Module):
def __init__(
self,
- in_channels: int,
- res_channels: int,
- num_residual_layers: int,
+ encoder: nn.Module,
+ decoder: nn.Module,
+ hidden_dim: int,
embedding_dim: int,
num_embeddings: int,
decay: float = 0.99,
- activation: str = "mish",
) -> None:
super().__init__()
- # Encoders
- self.btm_encoder = Encoder(
- in_channels=1,
- out_channels=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- embedding_dim=embedding_dim,
- activation=activation,
+ self.encoder = encoder
+ self.decoder = decoder
+ self.pre_codebook_conv = nn.Conv2d(
+ in_channels=hidden_dim, out_channels=embedding_dim, kernel_size=1
)
-
- self.top_encoder = Encoder(
- in_channels=embedding_dim,
- out_channels=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- embedding_dim=embedding_dim,
- activation=activation,
+ self.post_codebook_conv = nn.Conv2d(
+ in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1
)
-
- # Quantizers
- self.btm_quantizer = VectorQuantizer(
- num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
- )
-
- self.top_quantizer = VectorQuantizer(
+ self.quantizer = VectorQuantizer(
num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
)
- # Decoders
- self.top_decoder = Decoder(
- in_channels=embedding_dim,
- out_channels=embedding_dim,
- embedding_dim=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- activation=activation,
- )
-
- self.btm_decoder = Decoder(
- in_channels=2 * embedding_dim,
- out_channels=in_channels,
- embedding_dim=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- activation=activation,
- )
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def encode(self, x: Tensor) -> Tensor:
"""Encodes input to a latent code."""
- z_btm = self.btm_encoder(x)
- z_top = self.top_encoder(z_btm)
- return z_btm, z_top
+ z_e = self.encoder(x)
+ return self.pre_codebook_conv(z_e)
- def quantize(
- self, z_btm: Tensor, z_top: Tensor
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- q_btm, vq_btm_loss = self.top_quantizer(z_btm)
- q_top, vq_top_loss = self.top_quantizer(z_top)
- return q_btm, vq_btm_loss, q_top, vq_top_loss
+ def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
+ z_q, vq_loss = self.quantizer(z_e)
+ return z_q, vq_loss
- def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
+ def decode(self, z_q: Tensor) -> Tensor:
"""Reconstructs input from latent codes."""
- d_top = self.top_decoder(q_top)
- x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
- return d_top, x_hat
-
- def loss_fn(
- self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
- ) -> Tensor:
- """Calculates the latent loss."""
- return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
+ z = self.post_codebook_conv(z_q)
+ x_hat = self.decoder(z)
+ return x_hat
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Compresses and decompresses input."""
- z_btm, z_top = self.encode(x)
- q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
- d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
- vq_loss = self.loss_fn(
- vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
- )
+ z_e = self.encode(x)
+ z_q, vq_loss = self.quantize(z_e)
+ x_hat = self.decode(z_q)
return x_hat, vq_loss
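
The VQVAE is now a thin composition: any encoder/decoder pair plus 1x1 convolutions into and
out of the codebook. Note that its hidden_dim argument must match the encoder's output width,
i.e. the encoder's hidden_dim times its last multiplier. A composition sketch with illustrative
sizes, assuming the quantizer accepts [B, embedding_dim, H, W] latents and the sibling Encoder
and Decoder modules compose as above:

    import torch
    from text_recognizer.networks.vqvae.decoder import Decoder
    from text_recognizer.networks.vqvae.encoder import Encoder
    from text_recognizer.networks.vqvae.vqvae import VQVAE

    encoder = Encoder(
        in_channels=1, hidden_dim=32, channels_multipliers=[1, 2], dropout_rate=0.1
    )
    decoder = Decoder(
        out_channels=1, hidden_dim=32, channels_multipliers=[2, 1], dropout_rate=0.1
    )
    vqvae = VQVAE(
        encoder=encoder,
        decoder=decoder,
        hidden_dim=64,  # = encoder hidden_dim * channels_multipliers[-1]
        embedding_dim=16,
        num_embeddings=256,
    )
    x = torch.randn(1, 1, 32, 32)
    x_hat, vq_loss = vqvae(x)  # expected x_hat: [1, 1, 32, 32]
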