summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/vq_transformer.py94
-rw-r--r--text_recognizer/models/vqgan.py116
-rw-r--r--text_recognizer/models/vqvae.py45
3 files changed, 0 insertions, 255 deletions
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py
deleted file mode 100644
index 8ec28fd..0000000
--- a/text_recognizer/models/vq_transformer.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple, Type
-
-import attr
-import torch
-from torch import Tensor
-
-from text_recognizer.models.transformer import TransformerLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VqTransformerLitModel(TransformerLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.predict(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, targets = batch
- logits, commitment_loss = self.network(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
- self.log("train/loss", loss)
- self.log("train/commitment_loss", commitment_loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, targets = batch
- logits, commitment_loss = self.network(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
- self.log("val/loss", loss, prog_bar=True)
- self.log("val/commitment_loss", commitment_loss)
-
- # Get the token prediction.
- # pred = self(data)
- # self.val_cer(pred, targets)
- # self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
- # self.test_acc(pred, targets)
- # self.log("val/acc", self.test_acc, on_step=False, on_epoch=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, targets = batch
- pred = self(data)
- self.test_cer(pred, targets)
- self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
- self.test_acc(pred, targets)
- self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-
- def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- z, _ = self.network.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
-
- for Sy in range(1, self.max_output_len):
- context = output[:, :Sy] # (B, Sy)
- logits = self.network.decode(z, context) # (B, C, Sy)
- tokens = torch.argmax(logits, dim=1) # (B, Sy)
- output[:, Sy : Sy + 1] = tokens[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (output[:, Sy - 1] == self.end_index)
- | (output[:, Sy - 1] == self.pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for Sy in range(1, self.max_output_len):
- idx = (output[:, Sy - 1] == self.end_index) | (
- output[:, Sy - 1] == self.pad_index
- )
- output[idx, Sy] = self.pad_index
-
- return output
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
deleted file mode 100644
index 6a90e06..0000000
--- a/text_recognizer/models/vqgan.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.criterion.vqgan_loss import VQGANLoss
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQGANLitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- loss_fn: VQGANLoss = attr.ib()
- latent_loss_weight: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(
- self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
- ) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- if optimizer_idx == 0:
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=optimizer_idx,
- global_step=self.global_step,
- stage="train",
- )
- self.log(
- "train/loss", loss, prog_bar=True,
- )
- self.log_dict(log, logger=True, on_step=True, on_epoch=True)
- return loss
-
- if optimizer_idx == 1:
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=optimizer_idx,
- global_step=self.global_step,
- stage="train",
- )
- self.log(
- "train/discriminator_loss", loss, prog_bar=True,
- )
- self.log_dict(log, logger=True, on_step=True, on_epoch=True)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- loss, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=0,
- global_step=self.global_step,
- stage="val",
- )
- self.log(
- "val/loss", loss, prog_bar=True,
- )
- self.log_dict(log)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=1,
- global_step=self.global_step,
- stage="val",
- )
- self.log_dict(log)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=0,
- global_step=self.global_step,
- stage="test",
- )
- self.log_dict(log)
-
- _, log = self.loss_fn(
- data=data,
- reconstructions=reconstructions,
- commitment_loss=commitment_loss,
- decoder_last_layer=self.network.decoder.decoder[-1].weight,
- optimizer_idx=1,
- global_step=self.global_step,
- stage="test",
- )
- self.log_dict(log)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
deleted file mode 100644
index 4898852..0000000
--- a/text_recognizer/models/vqvae.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- commitment: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("train/commitment_loss", commitment_loss)
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- self.log("val/commitment_loss", commitment_loss)
- self.log("val/loss", loss, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("test/commitment_loss", commitment_loss)
- self.log("test/loss", loss)