diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
commit | b44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch) | |
tree | 998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/models | |
parent | 3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff) |
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/vq_transformer.py | 94 | ||||
-rw-r--r-- | text_recognizer/models/vqgan.py | 116 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 45 |
3 files changed, 0 insertions, 255 deletions
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py deleted file mode 100644 index 8ec28fd..0000000 --- a/text_recognizer/models/vq_transformer.py +++ /dev/null @@ -1,94 +0,0 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Tuple, Type - -import attr -import torch -from torch import Tensor - -from text_recognizer.models.transformer import TransformerLitModel - - -@attr.s(auto_attribs=True, eq=False) -class VqTransformerLitModel(TransformerLitModel): - """A PyTorch Lightning model for transformer networks.""" - - def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" - return self.predict(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, targets = batch - logits, commitment_loss = self.network(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss - self.log("train/loss", loss) - self.log("train/commitment_loss", commitment_loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, targets = batch - logits, commitment_loss = self.network(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss - self.log("val/loss", loss, prog_bar=True) - self.log("val/commitment_loss", commitment_loss) - - # Get the token prediction. - # pred = self(data) - # self.val_cer(pred, targets) - # self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) - # self.test_acc(pred, targets) - # self.log("val/acc", self.test_acc, on_step=False, on_epoch=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, targets = batch - pred = self(data) - self.test_cer(pred, targets) - self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) - self.test_acc(pred, targets) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) - - def predict(self, x: Tensor) -> Tensor: - """Predicts text in image. - - Args: - x (Tensor): Image(s) to extract text from. - - Shapes: - - x: :math: `(B, H, W)` - - output: :math: `(B, S)` - - Returns: - Tensor: A tensor of token indices of the predictions from the model. - """ - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - z, _ = self.network.encode(x) - - # Create a placeholder matrix for storing outputs from the network - output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - output[:, 0] = self.start_index - - for Sy in range(1, self.max_output_len): - context = output[:, :Sy] # (B, Sy) - logits = self.network.decode(z, context) # (B, C, Sy) - tokens = torch.argmax(logits, dim=1) # (B, Sy) - output[:, Sy : Sy + 1] = tokens[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (output[:, Sy - 1] == self.end_index) - | (output[:, Sy - 1] == self.pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for Sy in range(1, self.max_output_len): - idx = (output[:, Sy - 1] == self.end_index) | ( - output[:, Sy - 1] == self.pad_index - ) - output[idx, Sy] = self.pad_index - - return output diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py deleted file mode 100644 index 6a90e06..0000000 --- a/text_recognizer/models/vqgan.py +++ /dev/null @@ -1,116 +0,0 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Tuple - -import attr -from torch import Tensor - -from text_recognizer.criterion.vqgan_loss import VQGANLoss -from text_recognizer.models.base import BaseLitModel - - -@attr.s(auto_attribs=True, eq=False) -class VQGANLitModel(BaseLitModel): - """A PyTorch Lightning model for transformer networks.""" - - loss_fn: VQGANLoss = attr.ib() - latent_loss_weight: float = attr.ib(default=0.25) - - def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" - return self.network(data) - - def training_step( - self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int - ) -> Tensor: - """Training step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - - if optimizer_idx == 0: - loss, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=optimizer_idx, - global_step=self.global_step, - stage="train", - ) - self.log( - "train/loss", loss, prog_bar=True, - ) - self.log_dict(log, logger=True, on_step=True, on_epoch=True) - return loss - - if optimizer_idx == 1: - loss, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=optimizer_idx, - global_step=self.global_step, - stage="train", - ) - self.log( - "train/discriminator_loss", loss, prog_bar=True, - ) - self.log_dict(log, logger=True, on_step=True, on_epoch=True) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - - loss, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=0, - global_step=self.global_step, - stage="val", - ) - self.log( - "val/loss", loss, prog_bar=True, - ) - self.log_dict(log) - - _, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=1, - global_step=self.global_step, - stage="val", - ) - self.log_dict(log) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - - _, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=0, - global_step=self.global_step, - stage="test", - ) - self.log_dict(log) - - _, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=1, - global_step=self.global_step, - stage="test", - ) - self.log_dict(log) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py deleted file mode 100644 index 4898852..0000000 --- a/text_recognizer/models/vqvae.py +++ /dev/null @@ -1,45 +0,0 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Tuple - -import attr -from torch import Tensor - -from text_recognizer.models.base import BaseLitModel - - -@attr.s(auto_attribs=True, eq=False) -class VQVAELitModel(BaseLitModel): - """A PyTorch Lightning model for transformer networks.""" - - commitment: float = attr.ib(default=0.25) - - def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - loss = loss + self.commitment * commitment_loss - self.log("train/commitment_loss", commitment_loss) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - self.log("val/commitment_loss", commitment_loss) - self.log("val/loss", loss, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - loss = loss + self.commitment * commitment_loss - self.log("test/commitment_loss", commitment_loss) - self.log("test/loss", loss) |