| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:45:36 +0200 |
| --- | --- | --- |
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:45:36 +0200 |
| commit | 8b2e5296b290f147935c58207fbfd9674394c7b3 (patch) | |
| tree | 9957280d78bf112f41f4ba339ee1a832cfa2acb9 /text_recognizer | |
| parent | 191aa5a5c080ed7bf7f6d0408aa1cac4295c57d2 (diff) | |
Remove vq and perceiver models
Diffstat (limited to 'text_recognizer')
| Mode | File | Lines changed |
| --- | --- | --- |
| -rw-r--r-- | text_recognizer/models/__init__.py | 2 |
| -rw-r--r-- | text_recognizer/models/perceiver.py | 76 |
| -rw-r--r-- | text_recognizer/models/vq_transformer.py | 113 |
3 files changed, 0 insertions, 191 deletions
```diff
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index 56e3e93..cc02487 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -1,4 +1,2 @@
 """PyTorch Lightning models modules."""
 from text_recognizer.models.transformer import LitTransformer
-from text_recognizer.models.perceiver import LitPerceiver
-from text_recognizer.models.vq_transformer import LitVqTransformer
diff --git a/text_recognizer/models/perceiver.py b/text_recognizer/models/perceiver.py
deleted file mode 100644
index c482235..0000000
--- a/text_recognizer/models/perceiver.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""Lightning model for base Perceiver."""
-from typing import Optional, Tuple, Type
-
-from omegaconf import DictConfig
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.models.base import LitBase
-from text_recognizer.models.metrics import CharacterErrorRate
-
-
-class LitPerceiver(LitBase):
-    """A PyTorch Lightning model for transformer networks."""
-
-    def __init__(
-        self,
-        network: Type[nn.Module],
-        loss_fn: Type[nn.Module],
-        optimizer_config: DictConfig,
-        lr_scheduler_config: Optional[DictConfig],
-        mapping: EmnistMapping,
-        max_output_len: int = 682,
-        start_token: str = "<s>",
-        end_token: str = "<e>",
-        pad_token: str = "<p>",
-    ) -> None:
-        super().__init__(
-            network, loss_fn, optimizer_config, lr_scheduler_config, mapping
-        )
-        self.max_output_len = max_output_len
-        self.start_token = start_token
-        self.end_token = end_token
-        self.pad_token = pad_token
-        self.start_index = int(self.mapping.get_index(self.start_token))
-        self.end_index = int(self.mapping.get_index(self.end_token))
-        self.pad_index = int(self.mapping.get_index(self.pad_token))
-        self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
-        self.val_cer = CharacterErrorRate(self.ignore_indices)
-        self.test_cer = CharacterErrorRate(self.ignore_indices)
-
-    def forward(self, data: Tensor) -> Tensor:
-        """Forward pass with the transformer network."""
-        return self.predict(data)
-
-    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
-        """Training step."""
-        data, targets = batch
-        logits = self.network(data)
-        loss = self.loss_fn(logits, targets)
-        self.log("train/loss", loss)
-        return loss
-
-    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Validation step."""
-        data, targets = batch
-        preds = self.predict(data)
-        self.val_acc(preds, targets)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
-        self.val_cer(preds, targets)
-        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-
-    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Test step."""
-        data, targets = batch
-
-        # Compute the text prediction.
-        pred = self(data)
-        self.test_cer(pred, targets)
-        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
-        self.test_acc(pred, targets)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-
-    @torch.no_grad()
-    def predict(self, x: Tensor) -> Tensor:
-        return self.network(x).argmax(dim=1)
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py
deleted file mode 100644
index 99f69c0..0000000
--- a/text_recognizer/models/vq_transformer.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""Lightning model for Vector Quantized Transformers."""
-from typing import Optional, Tuple, Type
-
-from omegaconf import DictConfig
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.data.mappings import EmnistMapping
-from text_recognizer.models.transformer import LitTransformer
-
-
-class LitVqTransformer(LitTransformer):
-    """A PyTorch Lightning model for transformer networks."""
-
-    def __init__(
-        self,
-        network: Type[nn.Module],
-        loss_fn: Type[nn.Module],
-        optimizer_config: DictConfig,
-        lr_scheduler_config: Optional[DictConfig],
-        mapping: EmnistMapping,
-        max_output_len: int = 682,
-        start_token: str = "<s>",
-        end_token: str = "<e>",
-        pad_token: str = "<p>",
-        vq_loss_weight: float = 0.1,
-    ) -> None:
-        super().__init__(
-            network,
-            loss_fn,
-            optimizer_config,
-            lr_scheduler_config,
-            mapping,
-            max_output_len,
-            start_token,
-            end_token,
-            pad_token,
-        )
-        self.vq_loss_weight = vq_loss_weight
-
-    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
-        """Training step."""
-        data, targets = batch
-        logits, vq_loss = self.network(data, targets[:, :-1])
-        loss = self.loss_fn(logits, targets[:, 1:])
-        total_loss = loss + self.vq_loss_weight * vq_loss
-        self.log("train/vq_loss", vq_loss)
-        self.log("train/loss", loss)
-        self.log("train/total_loss", total_loss)
-        return total_loss
-
-    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Validation step."""
-        data, targets = batch
-        preds = self.predict(data)
-        self.val_acc(preds, targets)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
-        self.val_cer(preds, targets)
-        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-
-    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Test step."""
-        data, targets = batch
-        pred = self(data)
-        self.test_cer(pred, targets)
-        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
-        self.test_acc(pred, targets)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-
-    @torch.no_grad()
-    def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image.
-
-        Args:
-            x (Tensor): Image(s) to extract text from.
-
-        Shapes:
-            - x: :math: `(B, H, W)`
-            - output: :math: `(B, S)`
-
-        Returns:
-            Tensor: A tensor of token indices of the predictions from the model.
-        """
-        bsz = x.shape[0]
-
-        # Encode image(s) to latent vectors.
-        z, _ = self.network.encode(x)
-
-        # Create a placeholder matrix for storing outputs from the network
-        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        output[:, 0] = self.start_index
-
-        for Sy in range(1, self.max_output_len):
-            context = output[:, :Sy]  # (B, Sy)
-            logits = self.network.decode(z, context)  # (B, C, Sy)
-            tokens = torch.argmax(logits, dim=1)  # (B, Sy)
-            output[:, Sy : Sy + 1] = tokens[:, -1:]
-
-            # Early stopping of prediction loop if token is end or padding token.
-            if (
-                (output[:, Sy - 1] == self.end_index)
-                | (output[:, Sy - 1] == self.pad_index)
-            ).all():
-                break
-
-        # Set all tokens after end token to pad token.
-        for Sy in range(1, self.max_output_len):
-            idx = (output[:, Sy - 1] == self.end_index) | (
-                output[:, Sy - 1] == self.pad_index
-            )
-            output[idx, Sy] = self.pad_index
-
-        return output
```
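For context, the most involved piece of the deleted code is `LitVqTransformer.predict`, which encodes the image once and then greedily decodes one token per position until every sequence in the batch has emitted an end or pad token. Below is a minimal, self-contained sketch of that greedy decoding pattern; the standalone `greedy_decode` helper, the `encode`/`decode` interface, and the token-index arguments are illustrative assumptions, not part of the remaining codebase.

```python
# Hypothetical sketch of the greedy decoding loop removed with LitVqTransformer.
# `network.encode`/`network.decode` and the token indices are assumptions made
# for illustration; only the overall pattern follows the deleted predict method.
import torch
from torch import Tensor, nn


@torch.no_grad()
def greedy_decode(
    network: nn.Module,
    images: Tensor,
    start_index: int,
    end_index: int,
    pad_index: int,
    max_output_len: int = 682,
) -> Tensor:
    """Greedily predict token indices of shape (B, S) from images of shape (B, H, W)."""
    bsz = images.shape[0]

    # Encode the image(s) once; the latent is reused at every decoding step.
    z, _ = network.encode(images)

    # Placeholder for the output sequences, seeded with the start token.
    output = torch.ones((bsz, max_output_len), dtype=torch.long, device=images.device)
    output[:, 0] = start_index

    for step in range(1, max_output_len):
        context = output[:, :step]           # (B, step) tokens decoded so far
        logits = network.decode(z, context)  # (B, C, step) class scores per position
        output[:, step] = logits.argmax(dim=1)[:, -1]  # greedy pick for the newest position

        # Stop early once every sequence has emitted an end or pad token.
        done = (output[:, step] == end_index) | (output[:, step] == pad_index)
        if done.all():
            break

    return output
```

Greedy argmax keeps the loop simple at the cost of search quality; the deleted method additionally back-filled every position after the first end token with the pad index before returning, a detail the sketch omits.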