1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
"""Implementation of a Vector Quantized Variational AutoEncoder.
Reference:
https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
"""
from typing import Tuple

from einops import rearrange
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
class EmbeddingEMA(nn.Module):
    """Codebook storage whose tensors are maintained by an exponential moving average.

    Every tensor is registered as a buffer rather than a parameter. Buffers are
    saved in the ``state_dict`` and move with the module, but they are not
    returned by ``parameters()``.

    Args:
        num_embeddings (int): Number of codebook entries (rows).
        embedding_dim (int): Size of each codebook entry (columns).
    """

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        # Codebook, one row per entry, filled with Kaiming-uniform values.
        codebook = nn.init.kaiming_uniform_(
            torch.zeros(num_embeddings, embedding_dim), nonlinearity="linear"
        )
        buffers = (
            ("weight", codebook),
            # Per-entry running count of assigned latents; starts at zero.
            ("cluster_size", torch.zeros(num_embeddings)),
            # Per-entry running sum of assigned latents; starts as an
            # independent copy of the codebook.
            ("weight_avg", codebook.clone()),
        )
        for buffer_name, tensor in buffers:
            self.register_buffer(buffer_name, tensor)
class VectorQuantizer(nn.Module):
    """Codebook of quantized vectors, updated with an exponential moving average.

    The codebook is stored in an ``EmbeddingEMA`` module as buffers. The
    optimizer never updates it. Instead, in training mode every forward pass
    moves each code towards the mean of the latents assigned to it (see
    ``_compute_ema``).

    Args:
        num_embeddings (int): Number of codes ``K`` in the codebook.
        embedding_dim (int): Dimensionality ``D`` of each code.
        decay (float): EMA decay factor applied to the running statistics.
            Defaults to 0.99.
        epsilon (float): Constant added to the running cluster sizes before
            dividing, to avoid division by zero for codes that have not been
            selected. Defaults to 1e-5.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        decay: float = 0.99,
        epsilon: float = 1.0e-5,
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.decay = decay
        self.epsilon = epsilon
        self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim)

    def _discretization_bottleneck(self, latent: Tensor) -> Tensor:
        """Replaces every latent vector by its nearest code in the codebook.

        The posterior categorical distribution is a one-hot vector that
        selects the code with the smallest squared L2 distance to the latent.
        The quantized latent is that code. In training mode the codebook is
        then updated with one EMA step.

        Args:
            latent (Tensor): The latent representation, embedding dim last.

        Shape:
            - latent: :math:`(B, H, W, D)`
            - output: :math:`(B, H, W, D)`

        Returns:
            Tensor: The quantized latent. It carries no gradient; the
            straight-through estimator is applied in ``forward``.
        """
        b, h, w, _ = latent.shape
        # Flatten to one latent vector per row.
        latent = rearrange(latent, "b h w d -> (b h w) d")
        codebook = self.embedding.weight  # [K, D]

        # Code selection (argmin) is non-differentiable, so avoid building an
        # autograd graph for the [BHW, K] distance matrix.
        with torch.no_grad():
            # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e for every (latent, code) pair.
            l2_distance = (
                torch.sum(latent ** 2, dim=1, keepdim=True)
                + torch.sum(codebook ** 2, dim=1)
                - 2 * latent @ codebook.t()
            )  # [BHW, K]
            encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1)  # [BHW, 1]
            # One-hot encodings, aka the discrete bottleneck.
            one_hot_encoding = torch.zeros(
                encoding_indices.shape[0], self.num_embeddings, device=latent.device
            )
            one_hot_encoding.scatter_(1, encoding_indices, 1)  # [BHW, K]
            quantized_latent = one_hot_encoding @ codebook  # [BHW, D]

        if self.training:
            self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)

        return rearrange(quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w)

    @torch.no_grad()
    def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
        """Moves every code towards the mean of the latents assigned to it.

        Keeps decayed running totals of how many latents were assigned to each
        code (``cluster_size``) and of their sum (``weight_avg``). It then sets
        each code to ``weight_avg / (cluster_size + epsilon)``. It runs
        without autograd because the codebook is updated only by this rule.

        Args:
            one_hot_encoding (Tensor): Code assignments, shape ``[N, K]``.
            latent (Tensor): Flattened latents, shape ``[N, D]``.
        """
        decay = self.decay
        ema = self.embedding

        batch_cluster_size = one_hot_encoding.sum(dim=0)  # [K]
        # Per-code sum (not mean) of the latents assigned to that code.
        batch_embedding_sum = (latent.t() @ one_hot_encoding).t()  # [K, D]

        ema.cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
        ema.weight_avg.mul_(decay).add_(batch_embedding_sum, alpha=1 - decay)

        # epsilon guards against division by zero for codes never selected.
        denominator = (ema.cluster_size + self.epsilon).unsqueeze(1)  # [K, 1]
        ema.weight.copy_(ema.weight_avg / denominator)

    def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
        """Commitment loss that keeps the encoder output close to its code.

        The VQ-VAE objective has two VQ terms:

        - a codebook loss that moves the codes towards the encoder outputs;
        - a commitment loss that keeps the encoder outputs close to their
          chosen codes, so that the outputs cannot grow arbitrarily.

        Here the codebook is updated by EMA (``_compute_ema``), so only the
        commitment term is computed. ``.detach()`` plays the role of the
        paper's stop-gradient (sg) operator, so gradients reach ``latent``
        only. The loss is returned unscaled, without any beta weighting.

        Args:
            latent (Tensor): The encoder output.
            quantized_latent (Tensor): The quantized latent, same shape as
                ``latent``.

        Returns:
            Tensor: Scalar mean squared error between the two.
        """
        return F.mse_loss(quantized_latent.detach(), latent)

    def forward(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
        """Quantizes ``latent`` and computes the commitment loss.

        Args:
            latent (Tensor): Encoder output of shape ``(B, D, H, W)``, where
                ``D`` equals ``embedding_dim``.

        Returns:
            Tuple[Tensor, Tensor]: A pair with two entries:

            - the quantized latent, shape ``(B, D, H, W)``;
            - the scalar commitment loss.

            The quantized latent uses the straight-through estimator. Its
            value is the selected code, and its gradient passes to ``latent``
            unchanged.
        """
        # Move the embedding dim last so each spatial position is one vector.
        latent = rearrange(latent, "b d h w -> b h w d")
        quantized_latent = self._discretization_bottleneck(latent)
        loss = self._commitment_loss(latent, quantized_latent)
        # Straight-through estimator: value of the code, gradient of the latent.
        quantized_latent = latent + (quantized_latent - latent).detach()
        quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
        return quantized_latent, loss
|