| field | value | date |
|---|---|---|
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
| commit | b44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch) | |
| tree | 998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/models | |
| parent | 3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff) | |
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/models')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | text_recognizer/models/vq_transformer.py | 94 |
| -rw-r--r-- | text_recognizer/models/vqgan.py | 116 |
| -rw-r--r-- | text_recognizer/models/vqvae.py | 45 |

3 files changed, 0 insertions, 255 deletions
```diff
diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py
deleted file mode 100644
index 8ec28fd..0000000
--- a/text_recognizer/models/vq_transformer.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple, Type
-
-import attr
-import torch
-from torch import Tensor
-
-from text_recognizer.models.transformer import TransformerLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VqTransformerLitModel(TransformerLitModel):
-    """A PyTorch Lightning model for transformer networks."""
-
-    def forward(self, data: Tensor) -> Tensor:
-        """Forward pass with the transformer network."""
-        return self.predict(data)
-
-    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
-        """Training step."""
-        data, targets = batch
-        logits, commitment_loss = self.network(data, targets[:, :-1])
-        loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
-        self.log("train/loss", loss)
-        self.log("train/commitment_loss", commitment_loss)
-        return loss
-
-    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Validation step."""
-        data, targets = batch
-        logits, commitment_loss = self.network(data, targets[:, :-1])
-        loss = self.loss_fn(logits, targets[:, 1:]) + commitment_loss
-        self.log("val/loss", loss, prog_bar=True)
-        self.log("val/commitment_loss", commitment_loss)
-
-        # Get the token prediction.
-        # pred = self(data)
-        # self.val_cer(pred, targets)
-        # self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
-        # self.test_acc(pred, targets)
-        # self.log("val/acc", self.test_acc, on_step=False, on_epoch=True)
-
-    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Test step."""
-        data, targets = batch
-        pred = self(data)
-        self.test_cer(pred, targets)
-        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
-        self.test_acc(pred, targets)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-
-    def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image.
-
-        Args:
-            x (Tensor): Image(s) to extract text from.
-
-        Shapes:
-            - x: :math: `(B, H, W)`
-            - output: :math: `(B, S)`
-
-        Returns:
-            Tensor: A tensor of token indices of the predictions from the model.
-        """
-        bsz = x.shape[0]
-
-        # Encode image(s) to latent vectors.
-        z, _ = self.network.encode(x)
-
-        # Create a placeholder matrix for storing outputs from the network
-        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        output[:, 0] = self.start_index
-
-        for Sy in range(1, self.max_output_len):
-            context = output[:, :Sy]  # (B, Sy)
-            logits = self.network.decode(z, context)  # (B, C, Sy)
-            tokens = torch.argmax(logits, dim=1)  # (B, Sy)
-            output[:, Sy : Sy + 1] = tokens[:, -1:]
-
-            # Early stopping of prediction loop if token is end or padding token.
-            if (
-                (output[:, Sy - 1] == self.end_index)
-                | (output[:, Sy - 1] == self.pad_index)
-            ).all():
-                break
-
-        # Set all tokens after end token to pad token.
-        for Sy in range(1, self.max_output_len):
-            idx = (output[:, Sy - 1] == self.end_index) | (
-                output[:, Sy - 1] == self.pad_index
-            )
-            output[idx, Sy] = self.pad_index
-
-        return output
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
deleted file mode 100644
index 6a90e06..0000000
--- a/text_recognizer/models/vqgan.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.criterion.vqgan_loss import VQGANLoss
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQGANLitModel(BaseLitModel):
-    """A PyTorch Lightning model for transformer networks."""
-
-    loss_fn: VQGANLoss = attr.ib()
-    latent_loss_weight: float = attr.ib(default=0.25)
-
-    def forward(self, data: Tensor) -> Tensor:
-        """Forward pass with the transformer network."""
-        return self.network(data)
-
-    def training_step(
-        self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
-    ) -> Tensor:
-        """Training step."""
-        data, _ = batch
-        reconstructions, commitment_loss = self(data)
-
-        if optimizer_idx == 0:
-            loss, log = self.loss_fn(
-                data=data,
-                reconstructions=reconstructions,
-                commitment_loss=commitment_loss,
-                decoder_last_layer=self.network.decoder.decoder[-1].weight,
-                optimizer_idx=optimizer_idx,
-                global_step=self.global_step,
-                stage="train",
-            )
-            self.log(
-                "train/loss", loss, prog_bar=True,
-            )
-            self.log_dict(log, logger=True, on_step=True, on_epoch=True)
-            return loss
-
-        if optimizer_idx == 1:
-            loss, log = self.loss_fn(
-                data=data,
-                reconstructions=reconstructions,
-                commitment_loss=commitment_loss,
-                decoder_last_layer=self.network.decoder.decoder[-1].weight,
-                optimizer_idx=optimizer_idx,
-                global_step=self.global_step,
-                stage="train",
-            )
-            self.log(
-                "train/discriminator_loss", loss, prog_bar=True,
-            )
-            self.log_dict(log, logger=True, on_step=True, on_epoch=True)
-            return loss
-
-    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Validation step."""
-        data, _ = batch
-        reconstructions, commitment_loss = self(data)
-
-        loss, log = self.loss_fn(
-            data=data,
-            reconstructions=reconstructions,
-            commitment_loss=commitment_loss,
-            decoder_last_layer=self.network.decoder.decoder[-1].weight,
-            optimizer_idx=0,
-            global_step=self.global_step,
-            stage="val",
-        )
-        self.log(
-            "val/loss", loss, prog_bar=True,
-        )
-        self.log_dict(log)
-
-        _, log = self.loss_fn(
-            data=data,
-            reconstructions=reconstructions,
-            commitment_loss=commitment_loss,
-            decoder_last_layer=self.network.decoder.decoder[-1].weight,
-            optimizer_idx=1,
-            global_step=self.global_step,
-            stage="val",
-        )
-        self.log_dict(log)
-
-    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Test step."""
-        data, _ = batch
-        reconstructions, commitment_loss = self(data)
-
-        _, log = self.loss_fn(
-            data=data,
-            reconstructions=reconstructions,
-            commitment_loss=commitment_loss,
-            decoder_last_layer=self.network.decoder.decoder[-1].weight,
-            optimizer_idx=0,
-            global_step=self.global_step,
-            stage="test",
-        )
-        self.log_dict(log)
-
-        _, log = self.loss_fn(
-            data=data,
-            reconstructions=reconstructions,
-            commitment_loss=commitment_loss,
-            decoder_last_layer=self.network.decoder.decoder[-1].weight,
-            optimizer_idx=1,
-            global_step=self.global_step,
-            stage="test",
-        )
-        self.log_dict(log)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
deleted file mode 100644
index 4898852..0000000
--- a/text_recognizer/models/vqvae.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
-    """A PyTorch Lightning model for transformer networks."""
-
-    commitment: float = attr.ib(default=0.25)
-
-    def forward(self, data: Tensor) -> Tensor:
-        """Forward pass with the transformer network."""
-        return self.network(data)
-
-    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
-        """Training step."""
-        data, _ = batch
-        reconstructions, commitment_loss = self(data)
-        loss = self.loss_fn(reconstructions, data)
-        loss = loss + self.commitment * commitment_loss
-        self.log("train/commitment_loss", commitment_loss)
-        self.log("train/loss", loss)
-        return loss
-
-    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Validation step."""
-        data, _ = batch
-        reconstructions, commitment_loss = self(data)
-        loss = self.loss_fn(reconstructions, data)
-        self.log("val/commitment_loss", commitment_loss)
-        self.log("val/loss", loss, prog_bar=True)
-
-    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
-        """Test step."""
-        data, _ = batch
-        reconstructions, commitment_loss = self(data)
-        loss = self.loss_fn(reconstructions, data)
-        loss = loss + self.commitment * commitment_loss
-        self.log("test/commitment_loss", commitment_loss)
-        self.log("test/loss", loss)
```
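For context: all three removed lit models unpack a `commitment_loss` from `self.network`, which comes from the vector-quantization bottleneck inside the (also removed) VQ networks. The sketch below is a minimal, self-contained illustration of the standard straight-through VQ-VAE formulation, not this repo's implementation; the class name `VectorQuantizer` and its parameters are hypothetical.

```python
"""Illustrative VQ bottleneck sketch (hypothetical, not this repo's code)."""
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class VectorQuantizer(nn.Module):
    """Snaps encoder outputs to their nearest codebook vectors.

    Returns the quantized latents plus the scalar commitment loss that a
    lit model like the removed VQVAELitModel adds to its reconstruction
    loss (weighted by its `commitment` attribute, default 0.25).
    """

    def __init__(self, num_codes: int, dim: int, beta: float = 0.25) -> None:
        super().__init__()
        self.beta = beta
        self.codebook = nn.Embedding(num_codes, dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

    def forward(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
        # z_e: (B, N, D) encoder outputs, flattened over spatial positions.
        # Pairwise L2 distances to every codebook entry: (B, N, K).
        codes = self.codebook.weight.unsqueeze(0).expand(z_e.shape[0], -1, -1)
        indices = torch.cdist(z_e, codes).argmin(dim=-1)  # (B, N)
        z_q = self.codebook(indices)                      # (B, N, D)

        # Codebook term pulls embeddings toward (detached) encoder outputs;
        # the beta-weighted commitment term keeps the encoder near its codes.
        loss = F.mse_loss(z_q, z_e.detach()) + self.beta * F.mse_loss(z_e, z_q.detach())

        # Straight-through estimator: gradients flow to z_e as if
        # quantization were the identity.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, loss
```

With a bottleneck like this inside the network, `self.network(data)` returns a `(reconstructions, commitment_loss)` pair, which is exactly what the removed `training_step` implementations unpack.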
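The two `optimizer_idx` branches in the removed `VQGANLitModel.training_step` follow the PyTorch Lightning 1.x convention (current at this commit's date): returning two optimizers from `configure_optimizers` makes Lightning call `training_step` once per optimizer with the corresponding index. A minimal sketch, assuming the discriminator lives on the `VQGANLoss` module as in typical VQGAN setups; the attribute name and learning rate are guesses, not this repo's values:

```python
# Hypothetical optimizer pairing for the removed VQGANLitModel (Lightning 1.x).
def configure_optimizers(self):
    # Index 0: autoencoder/generator parameters.
    opt_g = torch.optim.Adam(self.network.parameters(), lr=1e-4)
    # Index 1: discriminator parameters (assumed to be owned by the loss).
    opt_d = torch.optim.Adam(self.loss_fn.discriminator.parameters(), lr=1e-4)
    # Two optimizers => Lightning 1.x invokes
    # training_step(batch, batch_idx, optimizer_idx) once per optimizer,
    # which selects the generator/discriminator branches shown above.
    return [opt_g, opt_d]
```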