authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-21 21:34:53 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-21 21:34:53 +0100
commitb44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch)
tree998841a3a681d3dedfbe8470c1b8544b4dcbe7a2
parent3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff)
Remove VQVAE stuff, did not work...
-rw-r--r--  text_recognizer/criterion/n_layer_discriminator.py    59
-rw-r--r--  text_recognizer/criterion/vqgan_loss.py               123
-rw-r--r--  text_recognizer/models/vq_transformer.py               94
-rw-r--r--  text_recognizer/models/vqgan.py                       116
-rw-r--r--  text_recognizer/models/vqvae.py                        45
-rw-r--r--  text_recognizer/networks/quantizer/__init__.py          0
-rw-r--r--  text_recognizer/networks/quantizer/codebook.py         96
-rw-r--r--  text_recognizer/networks/quantizer/kmeans.py           32
-rw-r--r--  text_recognizer/networks/quantizer/quantizer.py        59
-rw-r--r--  text_recognizer/networks/quantizer/utils.py            26
-rw-r--r--  text_recognizer/networks/vq_transformer.py             84
-rw-r--r--  text_recognizer/networks/vqvae/__init__.py              1
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py              93
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py              85
-rw-r--r--  text_recognizer/networks/vqvae/norm.py                 24
-rw-r--r--  text_recognizer/networks/vqvae/residual.py             54
-rw-r--r--  text_recognizer/networks/vqvae/resize.py               19
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py                42
-rw-r--r--  training/conf/callbacks/vae.yaml                        3
-rw-r--r--  training/conf/callbacks/wandb/reconstructions.yaml      4
-rw-r--r--  training/conf/criterion/mae.yaml                        2
-rw-r--r--  training/conf/criterion/mse.yaml                        2
-rw-r--r--  training/conf/criterion/vqgan_loss.yaml                11
-rw-r--r--  training/conf/datamodule/transform/paragraphs.yaml      2
-rw-r--r--  training/conf/experiment/vq_transformer_lines.yaml    149
-rw-r--r--  training/conf/experiment/vqgan.yaml                    98
-rw-r--r--  training/conf/experiment/vqgan_htr_char.yaml           59
-rw-r--r--  training/conf/experiment/vqvae.yaml                    51
-rw-r--r--  training/conf/model/lit_vqgan.yaml                      1
-rw-r--r--  training/conf/model/lit_vqvae.yaml                      1
-rw-r--r--  training/conf/network/decoder/vae_decoder.yaml          9
-rw-r--r--  training/conf/network/encoder/vae_encoder.yaml          9
-rw-r--r--  training/conf/network/quantizer.yaml                   12
-rw-r--r--  training/conf/network/vqvae.yaml                       18
34 files changed, 1 insertion(+), 1482 deletions(-)
diff --git a/text_recognizer/criterion/n_layer_discriminator.py b/text_recognizer/criterion/n_layer_discriminator.py
deleted file mode 100644
index a9f47f9..0000000
--- a/text_recognizer/criterion/n_layer_discriminator.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Pix2pix discriminator loss."""
-from torch import nn, Tensor
-
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-class NLayerDiscriminator(nn.Module):
- """Defines a PatchGAN discriminator loss in Pix2Pix."""
-
- def __init__(
- self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.num_channels = num_channels
- self.num_layers = num_layers
- self.discriminator = self._build_discriminator()
-
- def _build_discriminator(self) -> nn.Sequential:
- """Builds discriminator."""
- discriminator = [
- nn.Sigmoid(),
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.num_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- nn.Mish(inplace=True),
- ]
- in_channels = self.num_channels
- for n in range(1, self.num_layers):
- discriminator += [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=in_channels * n,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- # Normalize(num_channels=in_channels * n),
- nn.Mish(inplace=True),
- ]
- in_channels *= n
-
- discriminator += [
- nn.Conv2d(
- in_channels=self.num_channels * (self.num_layers - 1),
- out_channels=1,
- kernel_size=4,
- padding=1,
- )
- ]
- return nn.Sequential(*discriminator)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass through discriminator."""
- return self.discriminator(x)
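
For context, the module above is a PatchGAN discriminator: stacked strided 4x4 convolutions shrink the input so that each output logit scores one overlapping image patch rather than the whole image. A minimal standalone sketch of that idea (plain PyTorch, not the removed class):

import torch
from torch import nn

# Each 4x4/stride-2 conv halves the spatial size; the final 4x4/stride-1
# conv maps features to one logit per receptive-field patch.
discriminator = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
    nn.Mish(inplace=True),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
    nn.Mish(inplace=True),
    nn.Conv2d(64, 1, kernel_size=4, padding=1),
)

logits = discriminator(torch.randn(1, 1, 64, 64))
print(logits.shape)  # torch.Size([1, 1, 15, 15]) -- one logit per patch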
diff --git a/text_recognizer/criterion/vqgan_loss.py b/text_recognizer/criterion/vqgan_loss.py
deleted file mode 100644
index 8e8b65b..0000000
--- a/text_recognizer/criterion/vqgan_loss.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""VQGAN loss for PyTorch Lightning."""
-from typing import Optional, Tuple
-
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.criterion.n_layer_discriminator import NLayerDiscriminator
-
-
-def _adopt_weight(
- weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0
-) -> float:
- """Sets weight to value after the threshold is passed."""
- if global_step < threshold:
- weight = value
- return weight
-
-
-class VQGANLoss(nn.Module):
- """VQGAN loss."""
-
- def __init__(
- self,
- reconstruction_loss: nn.L1Loss,
- discriminator: NLayerDiscriminator,
- commitment_weight: float = 1.0,
- discriminator_weight: float = 1.0,
- discriminator_factor: float = 1.0,
- discriminator_iter_start: int = 1000,
- ) -> None:
- super().__init__()
- self.reconstruction_loss = reconstruction_loss
- self.discriminator = discriminator
- self.commitment_weight = commitment_weight
- self.discriminator_weight = discriminator_weight
- self.discriminator_factor = discriminator_factor
- self.discriminator_iter_start = discriminator_iter_start
-
- @staticmethod
- def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor:
- """Calculates the adversarial loss."""
- loss_real = torch.mean(F.relu(1.0 - logits_real))
- loss_fake = torch.mean(F.relu(1.0 + logits_fake))
- d_loss = (loss_real + loss_fake) / 2.0
- return d_loss
-
- def _adaptive_weight(
- self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor
- ) -> Tensor:
- rec_grads = torch.autograd.grad(
- rec_loss, decoder_last_layer, retain_graph=True
- )[0]
- g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0]
- d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach()
- d_weight *= self.discriminator_weight
- return d_weight
-
- def forward(
- self,
- data: Tensor,
- reconstructions: Tensor,
- commitment_loss: Tensor,
- decoder_last_layer: Tensor,
- optimizer_idx: int,
- global_step: int,
- stage: str,
- ) -> Optional[Tuple]:
- """Calculates the VQGAN loss."""
- rec_loss: Tensor = self.reconstruction_loss(reconstructions, data)
-
- # GAN part.
- if optimizer_idx == 0:
- logits_fake = self.discriminator(reconstructions)
- g_loss = -torch.mean(logits_fake)
-
- if self.training:
- d_weight = self._adaptive_weight(
- rec_loss=rec_loss,
- g_loss=g_loss,
- decoder_last_layer=decoder_last_layer,
- )
- else:
- d_weight = torch.tensor(0.0)
-
- d_factor = _adopt_weight(
- self.discriminator_factor,
- global_step=global_step,
- threshold=self.discriminator_iter_start,
- )
-
- loss: Tensor = (
- rec_loss
- + d_factor * d_weight * g_loss
- + self.commitment_weight * commitment_loss
- )
- log = {
- f"{stage}/total_loss": loss,
- f"{stage}/commitment_loss": commitment_loss,
- f"{stage}/rec_loss": rec_loss,
- f"{stage}/g_loss": g_loss,
- }
- return loss, log
-
- if optimizer_idx == 1:
- logits_fake = self.discriminator(reconstructions.detach())
- logits_real = self.discriminator(data.detach())
-
- d_factor = _adopt_weight(
- self.discriminator_factor,
- global_step=global_step,
- threshold=self.discriminator_iter_start,
- )
-
- d_loss = d_factor * self.adversarial_loss(
- logits_real=logits_real, logits_fake=logits_fake
- )
-
- log = {
- f"{stage}/d_loss": d_loss,
- }
- return d_loss, log
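
For context, `_adaptive_weight` above rescales the generator loss so that its gradient at the decoder's last layer matches the reconstruction gradient in norm, following the VQGAN (taming-transformers) recipe. A self-contained sketch with stand-in scalar losses in place of the real reconstruction and discriminator terms:

import torch

last_layer = torch.randn(8, 8, requires_grad=True)  # stand-in decoder weight
rec_loss = (last_layer ** 2).mean()                 # stand-in reconstruction loss
g_loss = last_layer.sum()                           # stand-in generator loss

rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

# Ratio of gradient norms, clamped and detached exactly as in the removed code.
d_weight = (rec_grads.norm() / (g_grads.norm() + 1.0e-4)).clamp(0.0, 1.0e4).detach()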
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py
deleted file mode 100644
index 8ec28fd..0000000
--- a/text_recognizer/models/vq_transformer.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple, Type
-
-import attr
-import torch
-from torch import Tensor
-
-from text_recognizer.models.transformer import TransformerLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VqTransformerLitModel(TransformerLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.predict(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, targets = batch
- logits, commitment_loss = self.network(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
- self.log("train/loss", loss)
- self.log("train/commitment_loss", commitment_loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, targets = batch
- logits, commitment_loss = self.network(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
- self.log("val/loss", loss, prog_bar=True)
- self.log("val/commitment_loss", commitment_loss)
-
- # Get the token prediction.
- # pred = self(data)
- # self.val_cer(pred, targets)
- # self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
- # self.test_acc(pred, targets)
- # self.log("val/acc", self.test_acc, on_step=False, on_epoch=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, targets = batch
- pred = self(data)
- self.test_cer(pred, targets)
- self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
- self.test_acc(pred, targets)
- self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-
- def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- z, _ = self.network.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
-
- for Sy in range(1, self.max_output_len):
- context = output[:, :Sy] # (B, Sy)
- logits = self.network.decode(z, context) # (B, C, Sy)
- tokens = torch.argmax(logits, dim=1) # (B, Sy)
- output[:, Sy : Sy + 1] = tokens[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (output[:, Sy - 1] == self.end_index)
- | (output[:, Sy - 1] == self.pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for Sy in range(1, self.max_output_len):
- idx = (output[:, Sy - 1] == self.end_index) | (
- output[:, Sy - 1] == self.pad_index
- )
- output[idx, Sy] = self.pad_index
-
- return output
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
deleted file mode 100644
index 6a90e06..0000000
--- a/text_recognizer/models/vqgan.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.criterion.vqgan_loss import VQGANLoss
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQGANLitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- loss_fn: VQGANLoss = attr.ib()
- latent_loss_weight: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(
- self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
- ) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- if optimizer_idx == 0:
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=optimizer_idx,
- global_step=self.global_step,
- stage="train",
- )
- self.log(
- "train/loss", loss, prog_bar=True,
- )
- self.log_dict(log, logger=True, on_step=True, on_epoch=True)
- return loss
-
- if optimizer_idx == 1:
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=optimizer_idx,
- global_step=self.global_step,
- stage="train",
- )
- self.log(
- "train/discriminator_loss", loss, prog_bar=True,
- )
- self.log_dict(log, logger=True, on_step=True, on_epoch=True)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=0,
- global_step=self.global_step,
- stage="val",
- )
- self.log(
- "val/loss", loss, prog_bar=True,
- )
- self.log_dict(log)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=1,
- global_step=self.global_step,
- stage="val",
- )
- self.log_dict(log)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=0,
- global_step=self.global_step,
- stage="test",
- )
- self.log_dict(log)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=1,
- global_step=self.global_step,
- stage="test",
- )
- self.log_dict(log)
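
For context, the `optimizer_idx` branching above follows the pre-2.0 PyTorch Lightning convention of returning one optimizer per adversarial player from `configure_optimizers`. A minimal sketch of that hook, with `torch.optim.Adam` standing in for the MADGRAD optimizers used in the experiment config:

import torch

def configure_optimizers(self):
    # Generator (index 0) updates the VQVAE network; discriminator (index 1)
    # updates the PatchGAN held inside the loss function.
    opt_g = torch.optim.Adam(self.network.parameters(), lr=1.0e-4)
    opt_d = torch.optim.Adam(self.loss_fn.discriminator.parameters(), lr=4.5e-6)
    return [opt_g, opt_d]  # training_step then receives optimizer_idx 0 / 1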
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
deleted file mode 100644
index 4898852..0000000
--- a/text_recognizer/models/vqvae.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- commitment: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("train/commitment_loss", commitment_loss)
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- self.log("val/commitment_loss", commitment_loss)
- self.log("val/loss", loss, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("test/commitment_loss", commitment_loss)
- self.log("test/loss", loss)
diff --git a/text_recognizer/networks/quantizer/__init__.py b/text_recognizer/networks/quantizer/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/text_recognizer/networks/quantizer/__init__.py
+++ /dev/null
diff --git a/text_recognizer/networks/quantizer/codebook.py b/text_recognizer/networks/quantizer/codebook.py
deleted file mode 100644
index cb9bc59..0000000
--- a/text_recognizer/networks/quantizer/codebook.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Codebook module."""
-from typing import Tuple
-
-import attr
-from einops import rearrange
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.quantizer.kmeans import kmeans
-from text_recognizer.networks.quantizer.utils import (
- ema_inplace,
- norm,
- sample_vectors,
-)
-
-
-@attr.s(eq=False)
-class CosineSimilarityCodebook(nn.Module):
- """Cosine similarity codebook."""
-
- dim: int = attr.ib()
- codebook_size: int = attr.ib()
- kmeans_init: bool = attr.ib(default=False)
- kmeans_iters: int = attr.ib(default=10)
- decay: float = attr.ib(default=0.8)
- eps: float = attr.ib(default=1.0e-5)
- threshold_dead: int = attr.ib(default=2)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- if not self.kmeans_init:
- embeddings = norm(torch.randn(self.codebook_size, self.dim))
- else:
- embeddings = torch.zeros(self.codebook_size, self.dim)
- self.register_buffer("initalized", Tensor([not self.kmeans_init]))
- self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
- self.register_buffer("embeddings", embeddings)
-
- def _initialize_embedding(self, data: Tensor) -> None:
- embeddings, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
- self.embeddings.data.copy_(embeddings)
- self.cluster_size.data.copy_(cluster_size)
- self.initialized.data.copy_(Tensor([True]))
-
- def _replace(self, samples: Tensor, mask: Tensor) -> None:
- samples = norm(samples)
- modified_codebook = torch.where(
- mask[..., None],
- sample_vectors(samples, self.codebook_size),
- self.embeddings,
- )
- self.embeddings.data.copy_(modified_codebook)
-
- def _replace_dead_codes(self, batch_samples: Tensor) -> None:
- if self.threshold_dead == 0:
- return
- dead_codes = self.cluster_size < self.threshold_dead
- if not torch.any(dead_codes):
- return
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
- self._replace(batch_samples, mask=dead_codes)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes tensor."""
- shape = x.shape
- flatten = rearrange(x, "... d -> (...) d")
- flatten = norm(flatten)
-
- if not self.initialized:
- self._initialize_embedding(flatten)
-
- embeddings = norm(self.embeddings)
- dist = flatten @ embeddings.t()
- indices = dist.max(dim=-1).indices
- one_hot = F.one_hot(indices, self.codebook_size).type_as(x)
- indices = indices.view(*shape[:-1])
-
- quantized = F.embedding(indices, self.embeddings)
-
- if self.training:
- bins = one_hot.sum(0)
- ema_inplace(self.cluster_size, bins, self.decay)
- zero_mask = bins == 0
- bins = bins.masked_fill(zero_mask, 1.0)
-
- embed_sum = flatten.t() @ one_hot
- embed_norm = (embed_sum / bins.unsqueeze(0)).t()
- embed_norm = norm(embed_norm)
- embed_norm = torch.where(zero_mask[..., None], embeddings, embed_norm)
- ema_inplace(self.embeddings, embed_norm, self.decay)
- self._replace_dead_codes(x)
-
- return quantized, indices
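
For context, the lookup above works because both the flattened latents and the codebook rows are L2-normalized, so a plain matrix product equals cosine similarity. The nearest-code search in isolation:

import torch
import torch.nn.functional as F

vectors = F.normalize(torch.randn(128, 16), dim=-1)   # flattened latents
codebook = F.normalize(torch.randn(512, 16), dim=-1)  # codebook entries

indices = (vectors @ codebook.t()).argmax(dim=-1)     # nearest code per vector
quantized = codebook[indices]                         # (128, 16)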
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py
deleted file mode 100644
index a34c381..0000000
--- a/text_recognizer/networks/quantizer/kmeans.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""K-means clustering for embeddings."""
-from typing import Tuple
-
-from einops import repeat
-import torch
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.utils import norm, sample_vectors
-
-
-def kmeans(
- samples: Tensor, num_clusters: int, num_iters: int = 10
-) -> Tuple[Tensor, Tensor]:
- """Compute k-means clusters."""
- D = samples.shape[-1]
-
- means = sample_vectors(samples, num_clusters)
-
- for _ in range(num_iters):
- dists = samples @ means.t()
- buckets = dists.max(dim=-1).indices
- bins = torch.bincount(buckets, minlength=num_clusters)
- zero_mask = bins == 0
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
-
- new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
- new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
- new_means /= bins_min_clamped[..., None]
- new_means = norm(new_means)
- means = torch.where(zero_mask[..., None], means, new_means)
-
- return means, bins
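
A hypothetical usage of the helper above (import path as it was before this commit), seeding eight codes from normalized samples:

import torch
import torch.nn.functional as F

from text_recognizer.networks.quantizer.kmeans import kmeans  # pre-removal path

samples = F.normalize(torch.randn(1024, 16), dim=-1)
means, bins = kmeans(samples, num_clusters=8, num_iters=10)
print(means.shape, bins.shape)  # torch.Size([8, 16]) torch.Size([8])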
diff --git a/text_recognizer/networks/quantizer/quantizer.py b/text_recognizer/networks/quantizer/quantizer.py
deleted file mode 100644
index 3e8f0b2..0000000
--- a/text_recognizer/networks/quantizer/quantizer.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Implementation of a Vector Quantized Variational AutoEncoder.
-
-Reference:
-https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-"""
-from typing import Tuple, Type
-
-import attr
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-@attr.s(eq=False)
-class VectorQuantizer(nn.Module):
- """Vector quantizer."""
-
- input_dim: int = attr.ib()
- codebook: Type[nn.Module] = attr.ib()
- commitment: float = attr.ib(default=1.0)
- project_in: nn.Linear = attr.ib(default=None, init=False)
- project_out: nn.Linear = attr.ib(default=None, init=False)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- require_projection = self.codebook.dim != self.input_dim
- self.project_in = (
- nn.Linear(self.input_dim, self.codebook.dim)
- if require_projection
- else nn.Identity()
- )
- self.project_out = (
- nn.Linear(self.codebook.dim, self.input_dim)
- if require_projection
- else nn.Identity()
- )
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
- """Quantizes latent vectors."""
- H, W = x.shape[-2:]
- x = rearrange(x, "b d h w -> b (h w) d")
- x = self.project_in(x)
-
- quantized, indices = self.codebook(x)
-
- if self.training:
- commitment_loss = F.mse_loss(quantized.detach(), x) * self.commitment
- quantized = x + (quantized - x).detach()
- else:
- commitment_loss = torch.tensor([0.0]).type_as(x)
-
- quantized = self.project_out(quantized)
- quantized = rearrange(quantized, "b (h w) d -> b d h w", h=H, w=W)
-
- return quantized, indices, commitment_loss
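
For context, the line `quantized = x + (quantized - x).detach()` above is the straight-through estimator: the forward pass uses the quantized value, while the backward pass copies gradients to `x` unchanged because the detached term contributes none. In isolation, with rounding as a stand-in for the codebook lookup:

import torch

x = torch.randn(4, 16, requires_grad=True)
quantized = torch.round(x)                # stand-in for the codebook lookup
quantized = x + (quantized - x).detach()  # straight-through estimator

quantized.sum().backward()
print(x.grad)  # all ones: the gradient flows through as if no rounding happened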
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
deleted file mode 100644
index 0502d49..0000000
--- a/text_recognizer/networks/quantizer/utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Helper functions for quantization."""
-from typing import Tuple
-
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-
-
-def sample_vectors(samples: Tensor, num: int) -> Tensor:
- """Subsamples a set of vectors."""
- B, device = samples.shape[0], samples.device
- if B >= num:
- indices = torch.randperm(B, device=device)[:num]
- else:
- indices = torch.randint(0, B, (num,), device=device)
- return samples[indices]
-
-
-def norm(t: Tensor) -> Tensor:
- """Applies L2-normalization."""
- return F.normalize(t, p=2, dim=-1)
-
-
-def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
- """Applies exponential moving average."""
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
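
A worked example of `ema_inplace` with decay 0.8: the buffer moves 20% of the way toward each new observation.

import torch

def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))

moving_avg = torch.tensor([1.0])
ema_inplace(moving_avg, torch.tensor([2.0]), decay=0.8)
print(moving_avg)  # tensor([1.2000]) == 0.8 * 1.0 + 0.2 * 2.0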
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index a2bd81b..0000000
--- a/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""Vector quantized encoder, transformer decoder."""
-from typing import Optional, Tuple, Type
-
-from torch import nn, Tensor
-
-from text_recognizer.networks.conv_transformer import ConvTransformer
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-from text_recognizer.networks.transformer.layers import Decoder
-
-
-class VqTransformer(ConvTransformer):
- """Convolutional encoder and transformer decoder network."""
-
- def __init__(
- self,
- input_dims: Tuple[int, int, int],
- hidden_dim: int,
- num_classes: int,
- pad_index: Tensor,
- encoder: nn.Module,
- decoder: Decoder,
- pixel_pos_embedding: Type[nn.Module],
- quantizer: VectorQuantizer,
- token_pos_embedding: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__(
- input_dims=input_dims,
- hidden_dim=hidden_dim,
- num_classes=num_classes,
- pad_index=pad_index,
- encoder=encoder,
- decoder=decoder,
- pixel_pos_embedding=pixel_pos_embedding,
- token_pos_embedding=token_pos_embedding,
- )
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes an image into a discrete (VQ) latent representation.
-
- Args:
- x (Tensor): Image tensor.
-
- Shape:
- - x: :math: `(B, C, H, W)`
- - z: :math: `(B, Sx, E)`
-
- where Sx is the length of the flattened feature maps projected from
- the encoder, and E is the latent dimension of each pixel in the
- projected feature maps.
-
- Returns:
- Tuple[Tensor, Tensor]: A latent embedding of the image and the commitment loss.
- """
- z = self.encoder(x)
- z = self.conv(z)
- z, _, commitment_loss = self.quantizer(z)
- z = self.pixel_pos_embedding(z)
- z = z.flatten(start_dim=2)
-
- # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
- z = z.permute(0, 2, 1)
- return z, commitment_loss
-
- def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes images into word piece logits.
-
- Args:
- x (Tensor): Input image(s).
- context (Tensor): Target word embeddings.
-
- Shapes:
- - x: :math: `(B, C, H, W)`
- - context: :math: `(B, Sy, T)`
-
- where B is the batch size, C is the number of input channels, H is
- the image height and W is the image width.
-
- Returns:
- Tensor: Sequence of logits.
- """
- z, commitment_loss = self.encode(x)
- logits = self.decode(z, context)
- return logits, commitment_loss
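
For context, the reshape at the end of `encode` turns the `(B, E, Ho, Wo)` feature map into the `(B, Sx, E)` sequence the transformer decoder consumes. The reshape in isolation, using the shapes from the line-recognition experiment:

import torch

z = torch.randn(2, 512, 1, 32)  # (B, E, Ho, Wo) after quantization
z = z.flatten(start_dim=2)      # (B, E, Ho * Wo)
z = z.permute(0, 2, 1)          # (B, Sx, E) with Sx = Ho * Wo = 32
print(z.shape)                  # torch.Size([2, 32, 512])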
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
deleted file mode 100644
index e1f05fa..0000000
--- a/text_recognizer/networks/vqvae/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""VQ-VAE module."""
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
deleted file mode 100644
index 7734a5a..0000000
--- a/text_recognizer/networks/vqvae/decoder.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""CNN decoder for the VQ-VAE."""
-from typing import Sequence
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Decoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- out_channels: int,
- hidden_dim: int,
- channels_multipliers: Sequence[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.out_channels = out_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.decoder = self._build_decompression_block()
-
- def _build_decompression_block(self) -> nn.Sequential:
- decoder = []
- in_channels = self.hidden_dim * self.channels_multipliers[0]
- for _ in range(self.num_residuals):
- decoder += [
- Residual(
- in_channels=in_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- ),
- ]
-
- activation_fn = activation_function(self.activation)
- out_channels_multipliers = self.channels_multipliers + (1,)
- num_blocks = len(self.channels_multipliers)
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * self.channels_multipliers[i]
- out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
- if self.use_norm:
- decoder += [
- Normalize(num_channels=in_channels,),
- ]
- decoder += [
- activation_fn,
- nn.ConvTranspose2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- if self.use_norm:
- decoder += [
- Normalize(
- num_channels=self.hidden_dim * out_channels_multipliers[-1],
- num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4,
- ),
- ]
-
- decoder += [
- nn.Conv2d(
- in_channels=self.hidden_dim * out_channels_multipliers[-1],
- out_channels=self.out_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
- return nn.Sequential(*decoder)
-
- def forward(self, z_q: Tensor) -> Tensor:
- """Reconstruct input from given codes."""
- return self.decoder(z_q)
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
deleted file mode 100644
index 4761486..0000000
--- a/text_recognizer/networks/vqvae/encoder.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""CNN encoder for the VQ-VAE."""
-from typing import List, Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-
-
-class Encoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- in_channels: int,
- hidden_dim: int,
- channels_multipliers: List[int],
- dropout_rate: float,
- activation: str = "mish",
- use_norm: bool = False,
- num_residuals: int = 4,
- residual_channels: int = 32,
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.activation = activation
- self.dropout_rate = dropout_rate
- self.use_norm = use_norm
- self.num_residuals = num_residuals
- self.residual_channels = residual_channels
- self.encoder = self._build_compression_block()
-
- def _build_compression_block(self) -> nn.Sequential:
- """Builds encoder network."""
- num_blocks = len(self.channels_multipliers)
- channels_multipliers = (1,) + self.channels_multipliers
- activation_fn = activation_function(self.activation)
-
- encoder = [
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.hidden_dim,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * channels_multipliers[i]
- out_channels = self.hidden_dim * channels_multipliers[i + 1]
- if self.use_norm:
- encoder += [
- Normalize(num_channels=in_channels,),
- ]
- encoder += [
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- ]
-
- for _ in range(self.num_residuals):
- encoder += [
- Residual(
- in_channels=out_channels,
- residual_channels=self.residual_channels,
- use_norm=self.use_norm,
- activation=self.activation,
- )
- ]
-
- return nn.Sequential(*encoder)
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes input into a latent representation."""
- return self.encoder(x)
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
deleted file mode 100644
index d73f9f8..0000000
--- a/text_recognizer/networks/vqvae/norm.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""Normalizer block."""
-import attr
-from torch import nn, Tensor
-
-
-@attr.s(eq=False)
-class Normalize(nn.Module):
- num_channels: int = attr.ib()
- num_groups: int = attr.ib(default=32)
- norm: nn.GroupNorm = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- super().__init__()
- self.norm = nn.GroupNorm(
- num_groups=self.num_groups,
- num_channels=self.num_channels,
- eps=1.0e-6,
- affine=True,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies group normalization."""
- return self.norm(x)
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
deleted file mode 100644
index bdff9eb..0000000
--- a/text_recognizer/networks/vqvae/residual.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""Residual block."""
-import attr
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-@attr.s(eq=False)
-class Residual(nn.Module):
- in_channels: int = attr.ib()
- residual_channels: int = attr.ib()
- use_norm: bool = attr.ib(default=False)
- activation: str = attr.ib(default="relu")
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- super().__init__()
- self.block = self._build_res_block()
-
- def _build_res_block(self) -> nn.Sequential:
- """Build residual block."""
- block = []
- activation_fn = activation_function(activation=self.activation)
-
- if self.use_norm:
- block.append(Normalize(num_channels=self.in_channels))
-
- block += [
- activation_fn,
- nn.Conv2d(
- self.in_channels,
- self.residual_channels,
- kernel_size=3,
- padding=1,
- bias=False,
- ),
- ]
-
- if self.use_norm:
- block.append(Normalize(num_channels=self.residual_channels))
-
- block += [
- activation_fn,
- nn.Conv2d(
- self.residual_channels, self.in_channels, kernel_size=1, bias=False
- ),
- ]
- return nn.Sequential(*block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the residual forward pass."""
- return x + self.block(x)
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
deleted file mode 100644
index 8d67d02..0000000
--- a/text_recognizer/networks/vqvae/resize.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Up and down-sample with linear interpolation."""
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-
-class Upsample(nn.Module):
- """Upsamples by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies upsampling."""
- return F.interpolate(x, scale_factor=2.0, mode="nearest")
-
-
-class Downsample(nn.Module):
- """Downsampling by a factor 2."""
-
- def forward(self, x: Tensor) -> Tensor:
- """Applies downsampling."""
- return F.avg_pool2d(x, kernel_size=2, stride=2)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
deleted file mode 100644
index 5560e12..0000000
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""The VQ-VAE."""
-from typing import Tuple
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
-
-
-class VQVAE(nn.Module):
- """Vector Quantized Variational AutoEncoder."""
-
- def __init__(
- self,
- encoder: nn.Module,
- decoder: nn.Module,
- quantizer: VectorQuantizer,
- ) -> None:
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- self.quantizer = quantizer
-
- def encode(self, x: Tensor) -> Tensor:
- """Encodes input to a latent code."""
- return self.encoder(x)
-
- def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
- """Quantizes the encoded latent vectors."""
- z_q, _, commitment_loss = self.quantizer(z_e)
- return z_q, commitment_loss
-
- def decode(self, z_q: Tensor) -> Tensor:
- """Reconstructs input from latent codes."""
- return self.decoder(z_q)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Compresses and decompresses input."""
- z_e = self.encode(x)
- z_q, commitment_loss = self.quantize(z_e)
- x_hat = self.decode(z_q)
- return x_hat, commitment_loss
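
A round trip through the `VQVAE` wiring above, assuming the removed class is still importable (i.e., a checkout before this commit); single convolutions and a toy rounding quantizer stand in for the real encoder, decoder, and `VectorQuantizer`:

import torch
import torch.nn.functional as F
from torch import nn

class ToyQuantizer(nn.Module):
    """Stand-in returning (z_q, indices, commitment_loss) like VectorQuantizer."""

    def forward(self, z_e):
        z_q = torch.round(z_e)  # toy "codebook": round to nearest integer
        commitment_loss = F.mse_loss(z_q.detach(), z_e)
        return z_e + (z_q - z_e).detach(), None, commitment_loss

vqvae = VQVAE(
    encoder=nn.Conv2d(1, 8, kernel_size=3, padding=1),
    decoder=nn.Conv2d(8, 1, kernel_size=3, padding=1),
    quantizer=ToyQuantizer(),
)
x_hat, commitment_loss = vqvae(torch.randn(1, 1, 32, 32))
print(x_hat.shape)  # torch.Size([1, 1, 32, 32])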
diff --git a/training/conf/callbacks/vae.yaml b/training/conf/callbacks/vae.yaml
deleted file mode 100644
index 52adf69..0000000
--- a/training/conf/callbacks/vae.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-defaults:
- - default
- - wandb/reconstructions
diff --git a/training/conf/callbacks/wandb/reconstructions.yaml b/training/conf/callbacks/wandb/reconstructions.yaml
deleted file mode 100644
index 92f2d12..0000000
--- a/training/conf/callbacks/wandb/reconstructions.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-log_image_reconstruction:
- _target_: callbacks.wandb_callbacks.LogReconstuctedImages
- num_samples: 8
- use_sigmoid: true
diff --git a/training/conf/criterion/mae.yaml b/training/conf/criterion/mae.yaml
deleted file mode 100644
index cb07467..0000000
--- a/training/conf/criterion/mae.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-_target_: torch.nn.L1Loss
-reduction: mean
diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml
deleted file mode 100644
index ffd1403..0000000
--- a/training/conf/criterion/mse.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-_target_: torch.nn.MSELoss
-reduction: mean
diff --git a/training/conf/criterion/vqgan_loss.yaml b/training/conf/criterion/vqgan_loss.yaml
deleted file mode 100644
index 34a67ae..0000000
--- a/training/conf/criterion/vqgan_loss.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-_target_: text_recognizer.criterion.vqgan_loss.VQGANLoss
-reconstruction_loss:
- _target_: torch.nn.MSELoss
- reduction: mean
-discriminator:
- _target_: text_recognizer.criterion.n_layer_discriminator.NLayerDiscriminator
- in_channels: 1
- num_channels: 32
- num_layers: 3
-commitment_weight: 1.0
-discriminator_weight: 1.0
diff --git a/training/conf/datamodule/transform/paragraphs.yaml b/training/conf/datamodule/transform/paragraphs.yaml
index c1d2e4f..1b0ab04 100644
--- a/training/conf/datamodule/transform/paragraphs.yaml
+++ b/training/conf/datamodule/transform/paragraphs.yaml
@@ -18,7 +18,7 @@ random_affine:
random_perspective:
_target_: torchvision.transforms.RandomPerspective
- distortion_scale: 0.07
+ distortion_scale: 0.1
p: 0.5
fill: 0
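
The only non-deletion change in this commit: the perspective augmentation is strengthened from 0.07 to 0.1. The equivalent transform in code (torchvision):

import torch
from torchvision import transforms

transform = transforms.RandomPerspective(distortion_scale=0.1, p=0.5, fill=0)
augmented = transform(torch.rand(1, 56, 1024))  # applies to tensor images (C, H, W)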
diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml
deleted file mode 100644
index bbe1178..0000000
--- a/training/conf/experiment/vq_transformer_lines.yaml
+++ /dev/null
@@ -1,149 +0,0 @@
-# @package _global_
-
-defaults:
- - override /mapping: null
- - override /criterion: cross_entropy
- - override /callbacks: htr
- - override /datamodule: iam_lines
- - override /network: null
- - override /model: null
- - override /lr_schedulers: null
- - override /optimizers: null
-
-epochs: &epochs 512
-ignore_index: &ignore_index 3
-num_classes: &num_classes 57
-max_output_len: &max_output_len 89
-summary: [[1, 1, 56, 1024], [1, 89]]
-
-criterion:
- ignore_index: *ignore_index
-
-mapping: &mapping
- mapping:
- _target_: text_recognizer.data.mappings.emnist.EmnistMapping
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizers:
- madgrad:
- _target_: madgrad.MADGRAD
- lr: 3.0e-4
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- parameters: network
-
-lr_schedulers:
- network:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: *epochs
- eta_min: 1.0e-5
- last_epoch: -1
- interval: epoch
- monitor: val/loss
-
-datamodule:
- batch_size: 16
- num_workers: 12
- train_fraction: 0.9
- pin_memory: true
- << : *mapping
-
-rotary_embedding: &rotary_embedding
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
- dim: 64
-
-attn: &attn
- dim: &hidden_dim 512
- num_heads: 4
- dim_head: 64
- dropout_rate: &dropout_rate 0.4
-
-network:
- _target_: text_recognizer.networks.vq_transformer.VqTransformer
- input_dims: [1, 56, 1024]
- hidden_dim: *hidden_dim
- num_classes: *num_classes
- pad_index: *ignore_index
- encoder:
- _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
- arch: b1
- stochastic_dropout_rate: 0.2
- bn_momentum: 0.99
- bn_eps: 1.0e-3
- decoder:
- depth: 6
- _target_: text_recognizer.networks.transformer.layers.Decoder
- self_attn:
- _target_: text_recognizer.networks.transformer.attention.Attention
- << : *attn
- causal: true
- << : *rotary_embedding
- cross_attn:
- _target_: text_recognizer.networks.transformer.attention.Attention
- << : *attn
- causal: false
- norm:
- _target_: text_recognizer.networks.transformer.norm.ScaleNorm
- normalized_shape: *hidden_dim
- ff:
- _target_: text_recognizer.networks.transformer.mlp.FeedForward
- dim: *hidden_dim
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: *dropout_rate
- pre_norm: true
- pixel_pos_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
- dim: *hidden_dim
- shape: [1, 32]
- quantizer:
- _target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
- input_dim: 512
- codebook:
- _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
- dim: 16
- codebook_size: 4096
- kmeans_init: true
- kmeans_iters: 10
- decay: 0.8
- eps: 1.0e-5
- threshold_dead: 2
- commitment: 1.0
-
-model:
- _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
- << : *mapping
- max_output_len: *max_output_len
- start_token: <s>
- end_token: <e>
- pad_token: <p>
-
-trainer:
- _target_: pytorch_lightning.Trainer
- stochastic_weight_avg: true
- auto_scale_batch_size: binsearch
- auto_lr_find: false
- gradient_clip_val: 0.5
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: *epochs
- terminate_on_nan: true
- weights_summary: null
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
- resume_from_checkpoint: null
- accumulate_grad_batches: 1
- overfit_batches: 0
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
deleted file mode 100644
index 726757f..0000000
--- a/training/conf/experiment/vqgan.yaml
+++ /dev/null
@@ -1,98 +0,0 @@
-# @package _global_
-
-defaults:
- - override /network: vqvae
- - override /criterion: null
- - override /model: lit_vqgan
- - override /callbacks: vae
- - override /optimizers: null
- - override /lr_schedulers: null
-
-epochs: &epochs 100
-ignore_index: &ignore_index 3
-num_classes: &num_classes 58
-max_output_len: &max_output_len 682
-summary: [[1, 1, 576, 640]]
-
-criterion:
- _target_: text_recognizer.criterion.vqgan_loss.VQGANLoss
- reconstruction_loss:
- _target_: torch.nn.BCEWithLogitsLoss
- reduction: mean
- discriminator:
- _target_: text_recognizer.criterion.n_layer_discriminator.NLayerDiscriminator
- in_channels: 1
- num_channels: 64
- num_layers: 3
- commitment_weight: 0.25
- discriminator_weight: 0.8
- discriminator_factor: 1.0
- discriminator_iter_start: 8.0e4
-
-mapping: &mapping
- mapping:
- _target_: text_recognizer.data.mappings.emnist.EmnistMapping
- extra_symbols: [ "\n" ]
-
-datamodule:
- _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
- batch_size: 4
- num_workers: 12
- train_fraction: 0.9
- pin_memory: true
- << : *mapping
-
-lr_schedulers:
- network:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: *epochs
- eta_min: 1.0e-5
- last_epoch: -1
- interval: epoch
- monitor: val/loss
-
- discriminator:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
- T_max: *epochs
- eta_min: 1.0e-5
- last_epoch: -1
- interval: epoch
- monitor: val/loss
-
-optimizers:
- generator:
- _target_: madgrad.MADGRAD
- lr: 1.0e-4
- momentum: 0.5
- weight_decay: 0
- eps: 1.0e-7
-
- parameters: network
-
- discriminator:
- _target_: madgrad.MADGRAD
- lr: 4.5e-6
- momentum: 0.5
- weight_decay: 0
- eps: 1.0e-6
-
- parameters: loss_fn.discriminator
-
-trainer:
- _target_: pytorch_lightning.Trainer
- stochastic_weight_avg: false
- auto_scale_batch_size: binsearch
- auto_lr_find: false
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: *epochs
- terminate_on_nan: true
- weights_summary: null
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
- resume_from_checkpoint: null
- accumulate_grad_batches: 2
- overfit_batches: 0
diff --git a/training/conf/experiment/vqgan_htr_char.yaml b/training/conf/experiment/vqgan_htr_char.yaml
deleted file mode 100644
index af3fa40..0000000
--- a/training/conf/experiment/vqgan_htr_char.yaml
+++ /dev/null
@@ -1,59 +0,0 @@
-defaults:
- - override /mapping: null
- - override /network: null
- - override /model: null
-
-mapping:
- _target_: text_recognizer.data.emnist_mapping.EmnistMapping
- extra_symbols: [ "\n" ]
-
-datamodule:
- word_pieces: false
- batch_size: 8
- augment: false
-
-criterion:
- ignore_index: 3
-
-network:
- _target_: text_recognizer.networks.vq_transformer.VqTransformer
- input_dims: [1, 576, 640]
- encoder_dim: 32
- hidden_dim: 256
- dropout_rate: 0.1
- num_classes: 58
- pad_index: 3
- no_grad: true
- decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- dim: 256
- depth: 2
- num_heads: 8
- attn_fn: text_recognizer.networks.transformer.attention.Attention
- attn_kwargs:
- dim_head: 32
- dropout_rate: 0.2
- norm_fn: torch.nn.LayerNorm
- ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
- ff_kwargs:
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: 0.2
- cross_attend: true
- pre_norm: true
- rotary_emb: null
- pretrained_encoder_path: "training/logs/runs/2021-09-25/23-07-28"
-
-model:
- _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
- start_token: <s>
- end_token: <e>
- pad_token: <p>
- max_output_len: 682 # 451
- alpha: 1.0
-
-trainer:
- max_epochs: 64
- limit_train_batches: 0.1
- limit_val_batches: 0.1
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
deleted file mode 100644
index d069aef..0000000
--- a/training/conf/experiment/vqvae.yaml
+++ /dev/null
@@ -1,51 +0,0 @@
-defaults:
- - override /network: vqvae
- - override /criterion: mse
- - override /model: lit_vqvae
- - override /callbacks: wandb_vae
- - override /optimizers: null
- # - override /lr_schedulers:
- # - cosine_annealing
-
-# lr_schedulers: null
-# network:
-# _target_: torch.optim.lr_scheduler.OneCycleLR
-# max_lr: 1.0e-2
-# total_steps: null
-# epochs: 100
-# steps_per_epoch: 200
-# pct_start: 0.1
-# anneal_strategy: cos
-# cycle_momentum: true
-# base_momentum: 0.85
-# max_momentum: 0.95
-# div_factor: 25
-# final_div_factor: 1.0e4
-# three_phase: true
-# last_epoch: -1
-# verbose: false
-
-# # Non-class arguments
-# interval: step
-# monitor: val/loss
-
-optimizers:
- network:
- _target_: madgrad.MADGRAD
- lr: 1.0e-4
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-7
-
- parameters: network
-
-trainer:
- max_epochs: 128
- limit_train_batches: 0.1
- limit_val_batches: 0.1
-
-datamodule:
- batch_size: 8
- # resize: [288, 320]
-
-summary: null
diff --git a/training/conf/model/lit_vqgan.yaml b/training/conf/model/lit_vqgan.yaml
deleted file mode 100644
index 9ee1046..0000000
--- a/training/conf/model/lit_vqgan.yaml
+++ /dev/null
@@ -1 +0,0 @@
-_target_: text_recognizer.models.vqgan.VQGANLitModel
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
deleted file mode 100644
index 6dc44d7..0000000
--- a/training/conf/model/lit_vqvae.yaml
+++ /dev/null
@@ -1 +0,0 @@
-_target_: text_recognizer.models.vqvae.VQVAELitModel
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
deleted file mode 100644
index aed5733..0000000
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-_target_: text_recognizer.networks.vqvae.decoder.Decoder
-out_channels: 1
-hidden_dim: 32
-channels_multipliers: [4, 2, 1]
-dropout_rate: 0.0
-activation: mish
-use_norm: true
-num_residuals: 4
-residual_channels: 32
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
deleted file mode 100644
index 5d39bf7..0000000
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-_target_: text_recognizer.networks.vqvae.encoder.Encoder
-in_channels: 1
-hidden_dim: 32
-channels_multipliers: [1, 2, 4]
-dropout_rate: 0.0
-activation: mish
-use_norm: true
-num_residuals: 4
-residual_channels: 32
diff --git a/training/conf/network/quantizer.yaml b/training/conf/network/quantizer.yaml
deleted file mode 100644
index 827a247..0000000
--- a/training/conf/network/quantizer.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-_target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
-input_dim: 192
-codebook:
- _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
- dim: 16
- codebook_size: 2048
- kmeans_init: true
- kmeans_iters: 10
- decay: 0.8
- eps: 1.0e-5
- threshold_dead: 2
-commitment: 1.0
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
deleted file mode 100644
index 22f786f..0000000
--- a/training/conf/network/vqvae.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-defaults:
- - encoder: vae_encoder
- - decoder: vae_decoder
-
-_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-quantizer:
- _target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
- input_dim: 128
- codebook:
- _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
- dim: 8
- codebook_size: 512
- kmeans_init: true
- kmeans_iters: 10
- decay: 0.8
- eps: 1.0e-5
- threshold_dead: 2
- commitment: 1.0
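
For context, configs like the one above are consumed through Hydra, which recursively instantiates nested `_target_` nodes. A minimal sketch (hydra-core + omegaconf), assuming the removed classes are still importable:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "text_recognizer.networks.quantizer.quantizer.VectorQuantizer",
        "input_dim": 128,
        "codebook": {
            "_target_": "text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook",
            "dim": 8,
            "codebook_size": 512,
        },
        "commitment": 1.0,
    }
)
quantizer = instantiate(cfg)  # builds the codebook first, then the quantizer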