summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-21 21:34:53 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-21 21:34:53 +0100
commitb44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch)
tree998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer
parent3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff)
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/criterion/n_layer_discriminator.py59
-rw-r--r--text_recognizer/criterion/vqgan_loss.py123
-rw-r--r--text_recognizer/models/vq_transformer.py94
-rw-r--r--text_recognizer/models/vqgan.py116
-rw-r--r--text_recognizer/models/vqvae.py45
-rw-r--r--text_recognizer/networks/quantizer/__init__.py0
-rw-r--r--text_recognizer/networks/quantizer/codebook.py96
-rw-r--r--text_recognizer/networks/quantizer/kmeans.py32
-rw-r--r--text_recognizer/networks/quantizer/quantizer.py59
-rw-r--r--text_recognizer/networks/quantizer/utils.py26
-rw-r--r--text_recognizer/networks/vq_transformer.py84
-rw-r--r--text_recognizer/networks/vqvae/__init__.py1
-rw-r--r--text_recognizer/networks/vqvae/decoder.py93
-rw-r--r--text_recognizer/networks/vqvae/encoder.py85
-rw-r--r--text_recognizer/networks/vqvae/norm.py24
-rw-r--r--text_recognizer/networks/vqvae/residual.py54
-rw-r--r--text_recognizer/networks/vqvae/resize.py19
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py42
18 files changed, 0 insertions, 1052 deletions
diff --git a/text_recognizer/criterion/n_layer_discriminator.py b/text_recognizer/criterion/n_layer_discriminator.py
deleted file mode 100644
index a9f47f9..0000000
--- a/text_recognizer/criterion/n_layer_discriminator.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Pix2pix discriminator loss."""
-from torch import nn, Tensor
-
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-class NLayerDiscriminator(nn.Module):
- """Defines a PatchGAN discriminator loss in Pix2Pix."""
-
- def __init__(
- self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.num_channels = num_channels
- self.num_layers = num_layers
- self.discriminator = self._build_discriminator()
-
- def _build_discriminator(self) -> nn.Sequential:
- """Builds discriminator."""
- discriminator = [
- nn.Sigmoid(),
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.num_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- nn.Mish(inplace=True),
- ]
- in_channels = self.num_channels
- for n in range(1, self.num_layers):
- discriminator += [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=in_channels * n,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- # Normalize(num_channels=in_channels * n),
- nn.Mish(inplace=True),
- ]
- in_channels *= n
-
- discriminator += [
- nn.Conv2d(
- in_channels=self.num_channels * (self.num_layers - 1),
- out_channels=1,
- kernel_size=4,
- padding=1,
- )
- ]
- return nn.Sequential(*discriminator)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass through discriminator."""
- return self.discriminator(x)
diff --git a/text_recognizer/criterion/vqgan_loss.py b/text_recognizer/criterion/vqgan_loss.py
deleted file mode 100644
index 8e8b65b..0000000
--- a/text_recognizer/criterion/vqgan_loss.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""VQGAN loss for PyTorch Lightning."""
-from typing import Optional, Tuple
-
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.criterion.n_layer_discriminator import NLayerDiscriminator
-
-
-def _adopt_weight(
- weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0
-) -> float:
- """Sets weight to value after the threshold is passed."""
- if global_step < threshold:
- weight = value
- return weight
-
-
-class VQGANLoss(nn.Module):
- """VQGAN loss."""
-
- def __init__(
- self,
- reconstruction_loss: nn.L1Loss,
- discriminator: NLayerDiscriminator,
- commitment_weight: float = 1.0,
- discriminator_weight: float = 1.0,
- discriminator_factor: float = 1.0,
- discriminator_iter_start: int = 1000,
- ) -> None:
- super().__init__()
- self.reconstruction_loss = reconstruction_loss
- self.discriminator = discriminator
- self.commitment_weight = commitment_weight
- self.discriminator_weight = discriminator_weight
- self.discriminator_factor = discriminator_factor
- self.discriminator_iter_start = discriminator_iter_start
-
- @staticmethod
- def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor:
- """Calculates the adversarial loss."""
- loss_real = torch.mean(F.relu(1.0 - logits_real))
- loss_fake = torch.mean(F.relu(1.0 + logits_fake))
- d_loss = (loss_real + loss_fake) / 2.0
- return d_loss
-
- def _adaptive_weight(
- self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor
- ) -> Tensor:
- rec_grads = torch.autograd.grad(
- rec_loss, decoder_last_layer, retain_graph=True
- )[0]
- g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0]
- d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach()
- d_weight *= self.discriminator_weight
- return d_weight
-
- def forward(
- self,
- data: Tensor,
- reconstructions: Tensor,
- commitment_loss: Tensor,
- decoder_last_layer: Tensor,
- optimizer_idx: int,
- global_step: int,
- stage: str,
- ) -> Optional[Tuple]:
- """Calculates the VQGAN loss."""
- rec_loss: Tensor = self.reconstruction_loss(reconstructions, data)
-
- # GAN part.
- if optimizer_idx == 0:
- logits_fake = self.discriminator(reconstructions)
- g_loss = -torch.mean(logits_fake)
-
- if self.training:
- d_weight = self._adaptive_weight(
- rec_loss=rec_loss,
- g_loss=g_loss,
- decoder_last_layer=decoder_last_layer,
- )
- else:
- d_weight = torch.tensor(0.0)
-
- d_factor = _adopt_weight(
- self.discriminator_factor,
- global_step=global_step,
- threshold=self.discriminator_iter_start,
- )
-
- loss: Tensor = (
- rec_loss
- + d_factor * d_weight * g_loss
- + self.commitment_weight * commitment_loss
- )
- log = {
- f"{stage}/total_loss": loss,
- f"{stage}/commitment_loss": commitment_loss,
- f"{stage}/rec_loss": rec_loss,
- f"{stage}/g_loss": g_loss,
- }
- return loss, log
-
- if optimizer_idx == 1:
- logits_fake = self.discriminator(reconstructions.detach())
- logits_real = self.discriminator(data.detach())
-
- d_factor = _adopt_weight(
- self.discriminator_factor,
- global_step=global_step,
- threshold=self.discriminator_iter_start,
- )
-
- d_loss = d_factor * self.adversarial_loss(
- logits_real=logits_real, logits_fake=logits_fake
- )
-
- log = {
- f"{stage}/d_loss": d_loss,
- }
- return d_loss, log
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py
deleted file mode 100644
index 8ec28fd..0000000
--- a/text_recognizer/models/vq_transformer.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple, Type
-
-import attr
-import torch
-from torch import Tensor
-
-from text_recognizer.models.transformer import TransformerLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VqTransformerLitModel(TransformerLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.predict(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, targets = batch
- logits, commitment_loss = self.network(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
- self.log("train/loss", loss)
- self.log("train/commitment_loss", commitment_loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, targets = batch
- logits, commitment_loss = self.network(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
- self.log("val/loss", loss, prog_bar=True)
- self.log("val/commitment_loss", commitment_loss)
-
- # Get the token prediction.
- # pred = self(data)
- # self.val_cer(pred, targets)
- # self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
- # self.test_acc(pred, targets)
- # self.log("val/acc", self.test_acc, on_step=False, on_epoch=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, targets = batch
- pred = self(data)
- self.test_cer(pred, targets)
- self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
- self.test_acc(pred, targets)
- self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-
- def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- z, _ = self.network.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
-
- for Sy in range(1, self.max_output_len):
- context = output[:, :Sy] # (B, Sy)
- logits = self.network.decode(z, context) # (B, C, Sy)
- tokens = torch.argmax(logits, dim=1) # (B, Sy)
- output[:, Sy : Sy + 1] = tokens[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (output[:, Sy - 1] == self.end_index)
- | (output[:, Sy - 1] == self.pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for Sy in range(1, self.max_output_len):
- idx = (output[:, Sy - 1] == self.end_index) | (
- output[:, Sy - 1] == self.pad_index
- )
- output[idx, Sy] = self.pad_index
-
- return output
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
deleted file mode 100644
index 6a90e06..0000000
--- a/text_recognizer/models/vqgan.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.criterion.vqgan_loss import VQGANLoss
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQGANLitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- loss_fn: VQGANLoss = attr.ib()
- latent_loss_weight: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(
- self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
- ) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- if optimizer_idx == 0:
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=optimizer_idx,
- global_step=self.global_step,
- stage="train",
- )
- self.log(
- "train/loss", loss, prog_bar=True,
- )
- self.log_dict(log, logger=True, on_step=True, on_epoch=True)
- return loss
-
- if optimizer_idx == 1:
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=optimizer_idx,
- global_step=self.global_step,
- stage="train",
- )
- self.log(
- "train/discriminator_loss", loss, prog_bar=True,
- )
- self.log_dict(log, logger=True, on_step=True, on_epoch=True)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=0,
- global_step=self.global_step,
- stage="val",
- )
- self.log(
- "val/loss", loss, prog_bar=True,
- )
- self.log_dict(log)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=1,
- global_step=self.global_step,
- stage="val",
- )
- self.log_dict(log)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=0,
- global_step=self.global_step,
- stage="test",
- )
- self.log_dict(log)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=1,
- global_step=self.global_step,
- stage="test",
- )
- self.log_dict(log)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
deleted file mode 100644
index 4898852..0000000
--- a/text_recognizer/models/vqvae.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- commitment: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("train/commitment_loss", commitment_loss)
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- self.log("val/commitment_loss", commitment_loss)
- self.log("val/loss", loss, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("test/commitment_loss", commitment_loss)
- self.log("test/loss", loss)
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/text_recognizer/networks/quantizer/__init__.py
+++ /dev/null
diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py
deleted file mode 100644
index cb9bc59..0000000
--- a/text_recognizer/networks/quantizer/codebook.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Codebook module."""
-from typing import Tuple
-
-import attr
-from einops import rearrange
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.quantizer.kmeans import kmeans
-from text_recognizer.networks.quantizer.utils import (
- ema_inplace,
- norm,
- sample_vectors,
-)
-
-
-@attr.s(eq=False)
-class CosineSimilarityCodebook(nn.Module):
- """Cosine similarity codebook."""
-
- dim: int = attr.ib()
- codebook_size: int = attr.ib()
- kmeans_init: bool = attr.ib(default=False)
- kmeans_iters: int = attr.ib(default=10)
- decay: float = attr.ib(default=0.8)
- eps: float = attr.ib(default=1.0e-5)
- threshold_dead: int = attr.ib(default=2)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- if not self.kmeans_init:
- embeddings = norm(torch.randn(self.codebook_size, self.dim))
- else:
- embeddings = torch.zeros(self.codebook_size, self.dim)
- self.register_buffer("initalized", Tensor([not self.kmeans_init]))
- self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
- self.register_buffer("embeddings", embeddings)
-
- def _initalize_embedding(self, data: Tensor) -> None:
- embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
- self.embeddings.data.copy_(embeddings)
- self.cluster_size.data.copy_(cluster_size)
- self.initalized.data.copy_(Tensor([True]))
-
- def _replace(self, samples: Tensor, mask: Tensor) -> None:
- samples = norm(samples)
- modified_codebook = torch.where(
- mask[..., None],
- sample_vectors(samples, self.codebook_size),
- self.embeddings,
- )
- self.embeddings.data.copy_(modified_codebook)
-
- def _replace_dead_codes(self, batch_samples: Tensor) -> None:
- if self.threshold_dead == 0:
- return
- dead_codes = self.cluster_size < self.threshold_dead
- if not torch.any(dead_codes):
- return
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
- self._replace(batch_samples, mask=dead_codes)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes tensor."""
- shape = x.shape
- flatten = rearrange(x, "... d -> (...) d")
- flatten = norm(flatten)
-
- if not self.initalized:
- self._initalize_embedding(flatten)
-
- embeddings = norm(self.embeddings)
- dist = flatten @ embeddings.t()
- indices = dist.max(dim=-1).indices
- one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
- indices = indices.view(*shape[:-1])
-
- quantized = F.embedding(indices, self.embeddings)
-
- if self.training:
- bins = one_hot.sum(0)
- ema_inplace(self.cluster_size, bins, self.decay)
- zero_mask = bins == 0
- bins = bins.masked_fill(zero_mask, 1.0)
-
- embed_sum = flatten.t() @ one_hot
- embed_norm = (embed_sum / bins.unsqueeze(0)).t()
- embed_norm = norm(embed_norm)
- embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
- ema_inplace(self.embeddings, embed_norm, self.decay)
- self._replace_dead_codes(x)
-
- return quantized, indices
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
deleted file mode 100644
index a34c381..0000000
--- a/text_recognizer/networks/quantizer/kmeans.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""K-means clustering for embeddings."""
-from typing import Tuple
-
-from einops import repeat
-import torch
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.utils import norm, sample_vectors
-
-
-def kmeans(
- samples: Tensor, num_clusters: int, num_iters: int = 10
-) -> Tuple[Tensor, Tensor]:
- """Compute k-means clusters."""
- D = samples.shape[-1]
-
- means = sample_vectors(samples, num_clusters)
-
- for _ in range(num_iters):
- dists = samples @ means.t()
- buckets = dists.max(dim=-1).indices
- bins = torch.bincount(buckets, minlength=num_clusters)
- zero_mask = bins == 0
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
-
- new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
- new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
- new_means /= bins_min_clamped[..., None]
- new_means = norm(new_means)
- means = torch.where(zero_mask[..., None], means, new_means)
-
- return means, bins
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
deleted file mode 100644
index 3e8f0b2..0000000
--- a/text_recognizer/networks/quantizer/quantizer.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Implementation of a Vector Quantized Variational AutoEncoder.
-
-Reference:
-https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-"""
-from typing import Tuple, Type
-
-import attr
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-@attr.s(eq=False)
-class VectorQuantizer(nn.Module):
- """Vector quantizer."""
-
- input_dim: int = attr.ib()
- codebook: Type[nn.Module] = attr.ib()
- commitment: float = attr.ib(default=1.0)
- project_in: nn.Linear = attr.ib(default=None, init=False)
- project_out: nn.Linear = attr.ib(default=None, init=False)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- require_projection = self.codebook.dim != self.input_dim
- self.project_in = (
- nn.Linear(self.input_dim, self.codebook.dim)
- if require_projection
- else nn.Identity()
- )
- self.project_out = (
- nn.Linear(self.codebook.dim, self.input_dim)
- if require_projection
- else nn.Identity()
- )
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
- """Quantizes latent vectors."""
- H, W = x.shape[-2:]
- x = rearrange(x, "b d h w -> b (h w) d")
- x = self.project_in(x)
-
- quantized, indices = self.codebook(x)
-
- if self.training:
- commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment
- quantized = x + (quantized - x).detach()
- else:
- commitment_loss = torch.tensor([0.0]).type_as(x)
-
- quantized = self.project_out(quantized)
- quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
-
- return quantized, indices, commitment_loss
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
deleted file mode 100644
index 0502d49..0000000
--- a/text_recognizer/networks/quantizer/utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Helper functions for quantization."""
-from typing import Tuple
-
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-
-
-def sample_vectors(samples: Tensor, num: int) -> Tensor:
- """Subsamples a set of vectors."""
- B, device = samples.shape[0], samples.device
- if B >= num:
- indices = torch.randperm(B, device=device)[:num]
- else:
- indices = torch.randint(0, B, (num,), device=device)[:num]
- return samples[indices]
-
-
-def norm(t: Tensor) -> Tensor:
- """Applies L2-normalization."""
- return F.normalize(t, p=2, dim=-1)
-
-
-def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
- """Applies exponential moving average."""
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index a2bd81b..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""Vector quantized encoder, transformer decoder."""
-from typing import Optional, Tuple, Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conv_transformer import ConvTransformer
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-from text_recognizer.networks.transformer.layers import Decoder
-
-
-class VqTransformer(ConvTransformer):
- """Convolutional encoder and transformer decoder network."""
-
- def __init__(
- self,
- input_dims: Tuple[int, int, int],
- hidden_dim: int,
- num_classes: int,
- pad_index: Tensor,
- encoder: nn.Module,
- decoder: Decoder,
- pixel_pos_embedding: Type[nn.Module],
- quantizer: VectorQuantizer,
- token_pos_embedding: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__(
- input_dims=input_dims,
- hidden_dim=hidden_dim,
- num_classes=num_classes,
- pad_index=pad_index,
- encoder=encoder,
- decoder=decoder,
- pixel_pos_embedding=pixel_pos_embedding,
- token_pos_embedding=token_pos_embedding,
- )
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes an image into a discrete (VQ) latent representation.
-
- Args:
- x (Tensor): Image tensor.
-
- Shape:
- - x: :math: `(B, C, H, W)`
- - z: :math: `(B, Sx, E)`
-
- where Sx is the length of the flattened feature maps projected from
- the encoder. E latent dimension for each pixel in the projected
- feature maps.
-
- Returns:
- Tensor: A Latent embedding of the image.
- """
- z = self.encoder(x)
- z = self.conv(z)
- z, _, commitment_loss = self.quantizer(z)
- z = self.pixel_pos_embedding(z)
- z = z.flatten(start_dim=2)
-
- # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
- z = z.permute(0, 2, 1)
- return z, commitment_loss
-
- def forward(self, x: Tensor, context: Tensor) -> Tensor:
- """Encodes images into word piece logtis.
-
- Args:
- x (Tensor): Input image(s).
- context (Tensor): Target word embeddings.
-
- Shapes:
- - x: :math: `(B, C, H, W)`
- - context: :math: `(B, Sy, T)`
-
- where B is the batch size, C is the number of input channels, H is
- the image height and W is the image width.
-
- Returns:
- Tensor: Sequence of logits.
- """
- z, commitment_loss = self.encode(x)
- logits = self.decode(z, context)
- return logits, commitment_loss
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
deleted file mode 100644
index e1f05fa..0000000
--- a/text_recognizer/networks/vqvae/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""VQ-VAE module."""
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
deleted file mode 100644
index 7734a5a..0000000
--- a/text_recognizer/networks/vqvae/decoder.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""CNN decoder for the VQ-VAE."""
-from typing import Sequence
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Decoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- out_channels: int,
- hidden_dim: int,
- channels_multipliers: Sequence[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.out_channels = out_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.decoder = self._build_decompression_block()
-
- def _build_decompression_block(self,) -> nn.Sequential:
- decoder = []
- in_channels = self.hidden_dim * self.channels_multipliers[0]
- for _ in range(self.num_residuals):
- decoder += [
- Residual(
- in_channels=in_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- ),
- ]
-
- activation_fn = activation_function(self.activation)
- out_channels_multipliers = self.channels_multipliers + (1,)
- num_blocks = len(self.channels_multipliers)
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * self.channels_multipliers[i]
- out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
- if self.use_norm:
- decoder += [
- Normalize(num_channels=in_channels,),
- ]
- decoder += [
- activation_fn,
- nn.ConvTranspose2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- if self.use_norm:
- decoder += [
- Normalize(
- num_channels=self.hidden_dim * out_channels_multipliers[-1],
- num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4,
- ),
- ]
-
- decoder += [
- nn.Conv2d(
- in_channels=self.hidden_dim * out_channels_multipliers[-1],
- out_channels=self.out_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
- return nn.Sequential(*decoder)
-
- def forward(self, z_q: Tensor) -> Tensor:
- """Reconstruct input from given codes."""
- return self.decoder(z_q)
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
deleted file mode 100644
index 4761486..0000000
--- a/text_recognizer/networks/vqvae/encoder.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""CNN encoder for the VQ-VAE."""
-from typing import List, Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Encoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- in_channels: int,
- hidden_dim: int,
- channels_multipliers: List[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.encoder = self._build_compression_block()
-
- def _build_compression_block(self) -> nn.Sequential:
- """Builds encoder network."""
- num_blocks = len(self.channels_multipliers)
- channels_multipliers = (1,) + self.channels_multipliers
- activation_fn = activation_function(self.activation)
-
- encoder = [
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.hidden_dim,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * channels_multipliers[i]
- out_channels = self.hidden_dim * channels_multipliers[i + 1]
- if self.use_norm:
- encoder += [
- Normalize(num_channels=in_channels,),
- ]
- encoder += [
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- for _ in range(self.num_residuals):
- encoder += [
- Residual(
- in_channels=out_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- )
- ]
-
- return nn.Sequential(*encoder)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes input into a discrete representation."""
- return self.encoder(x)
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
deleted file mode 100644
index d73f9f8..0000000
--- a/text_recognizer/networks/vqvae/norm.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""Normalizer block."""
-import attr
-from torch import nn, Tensor
-
-
-@attr.s(eq=False)
-class Normalize(nn.Module):
- num_channels: int = attr.ib()
- num_groups: int = attr.ib(default=32)
- norm: nn.GroupNorm = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- super().__init__()
- self.norm = nn.GroupNorm(
- num_groups=self.num_groups,
- num_channels=self.num_channels,
- eps=1.0e-6,
- affine=True,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies group normalization."""
- return self.norm(x)
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
deleted file mode 100644
index bdff9eb..0000000
--- a/text_recognizer/networks/vqvae/residual.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""Residual block."""
-import attr
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-@attr.s(eq=False)
-class Residual(nn.Module):
- in_channels: int = attr.ib()
- residual_channels: int = attr.ib()
- use_norm: bool = attr.ib(default=False)
- activation: str = attr.ib(default="relu")
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- super().__init__()
- self.block = self._build_res_block()
-
- def _build_res_block(self) -> nn.Sequential:
- """Build residual block."""
- block = []
- activation_fn = activation_function(activation=self.activation)
-
- if self.use_norm:
- block.append(Normalize(num_channels=self.in_channels))
-
- block += [
- activation_fn,
- nn.Conv2d(
- self.in_channels,
- self.residual_channels,
- kernel_size=3,
- padding=1,
- bias=False,
- ),
- ]
-
- if self.use_norm:
- block.append(Normalize(num_channels=self.residual_channels))
-
- block += [
- activation_fn,
- nn.Conv2d(
- self.residual_channels, self.in_channels, kernel_size=1, bias=False
- ),
- ]
- return nn.Sequential(*block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the residual forward pass."""
- return x + self.block(x)
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
deleted file mode 100644
index 8d67d02..0000000
--- a/text_recognizer/networks/vqvae/resize.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Up and down-sample with linear interpolation."""
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-
-class Upsample(nn.Module):
- """Upsamples by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies upsampling."""
- return F.interpolate(x, scale_factor=2.0, mode="nearest")
-
-
-class Downsample(nn.Module):
- """Downsampling by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies downsampling."""
- return F.avg_pool2d(x, kernel_size=2, stride=2)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
deleted file mode 100644
index 5560e12..0000000
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""The VQ-VAE."""
-from typing import Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-
-
-class VQVAE(nn.Module):
- """Vector Quantized Variational AutoEncoder."""
-
- def __init__(
- self,
- encoder: nn.Module,
- decoder: nn.Module,
- quantizer: VectorQuantizer,
- ) -> None:
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tensor:
- """Encodes input to a latent code."""
- return self.encoder(x)
-
- def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes the encoded latent vectors."""
- z_q, _, commitment_loss = self.quantizer(z_e)
- return z_q, commitment_loss
-
- def decode(self, z_q: Tensor) -> Tensor:
- """Reconstructs input from latent codes."""
- return self.decoder(z_q)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Compresses and decompresses input."""
- z_e = self.encode(x)
- z_q, commitment_loss = self.quantize(z_e)
- x_hat = self.decode(z_q)
- return x_hat, commitment_loss