diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-03 12:13:02 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-03 12:13:02 +0200 |
commit | 73ccaaa24936faed36fcc467532baa5386d402ae (patch) | |
tree | c7230fff21b8a780c2b0cd8a5d610075cbb7f21e /text_recognizer | |
parent | 5dd76ca9a3ff35c57cbc7c607afbdb4ee1c8b36f (diff) |
Update perceiver
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/perceiver.py | 76 | ||||
-rw-r--r-- | text_recognizer/networks/conv_perceiver.py | 62 | ||||
-rw-r--r-- | text_recognizer/networks/perceiver/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/perceiver/attention.py | 10 | ||||
-rw-r--r-- | text_recognizer/networks/perceiver/perceiver.py | 22 |
5 files changed, 157 insertions, 14 deletions
diff --git a/text_recognizer/models/perceiver.py b/text_recognizer/models/perceiver.py new file mode 100644 index 0000000..c482235 --- /dev/null +++ b/text_recognizer/models/perceiver.py @@ -0,0 +1,76 @@ +"""Lightning model for base Perceiver.""" +from typing import Optional, Tuple, Type + +from omegaconf import DictConfig +import torch +from torch import nn, Tensor + +from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.models.base import LitBase +from text_recognizer.models.metrics import CharacterErrorRate + + +class LitPerceiver(LitBase): + """A PyTorch Lightning model for transformer networks.""" + + def __init__( + self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_config: DictConfig, + lr_scheduler_config: Optional[DictConfig], + mapping: EmnistMapping, + max_output_len: int = 682, + start_token: str = "<s>", + end_token: str = "<e>", + pad_token: str = "<p>", + ) -> None: + super().__init__( + network, loss_fn, optimizer_config, lr_scheduler_config, mapping + ) + self.max_output_len = max_output_len + self.start_token = start_token + self.end_token = end_token + self.pad_token = pad_token + self.start_index = int(self.mapping.get_index(self.start_token)) + self.end_index = int(self.mapping.get_index(self.end_token)) + self.pad_index = int(self.mapping.get_index(self.pad_token)) + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) + self.val_cer = CharacterErrorRate(self.ignore_indices) + self.test_cer = CharacterErrorRate(self.ignore_indices) + + def forward(self, data: Tensor) -> Tensor: + """Forward pass with the transformer network.""" + return self.predict(data) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + data, targets = batch + logits = self.network(data) + loss = self.loss_fn(logits, targets) + self.log("train/loss", loss) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, targets = batch + preds = self.predict(data) + self.val_acc(preds, targets) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) + self.val_cer(preds, targets) + self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, targets = batch + + # Compute the text prediction. + pred = self(data) + self.test_cer(pred, targets) + self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + + @torch.no_grad() + def predict(self, x: Tensor) -> Tensor: + return self.network(x).argmax(dim=1) diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py new file mode 100644 index 0000000..cda5e0a --- /dev/null +++ b/text_recognizer/networks/conv_perceiver.py @@ -0,0 +1,62 @@ +"""Perceiver network module.""" +from typing import Optional, Tuple, Type + +from torch import nn, Tensor + +from text_recognizer.networks.perceiver.perceiver import PerceiverIO +from text_recognizer.networks.transformer.embeddings.axial import ( + AxialPositionalEmbedding, +) + + +class ConvPerceiver(nn.Module): + """Base transformer network.""" + + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + queries_dim: int, + num_classes: int, + pad_index: Tensor, + encoder: Type[nn.Module], + decoder: PerceiverIO, + max_length: int, + pixel_embedding: AxialPositionalEmbedding, + ) -> None: + super().__init__() + self.input_dims = input_dims + self.hidden_dim = hidden_dim + self.num_classes = num_classes + self.pad_index = pad_index + self.max_length = max_length + self.encoder = encoder + self.decoder = decoder + self.pixel_embedding = pixel_embedding + self.to_queries = nn.Linear(self.hidden_dim, queries_dim) + self.conv = nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ) + + def encode(self, x: Tensor) -> Tensor: + z = self.encoder(x) + z = self.conv(z) + z = self.pixel_embedding(z) + z = z.flatten(start_dim=2) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z + + def decode(self, z: Tensor) -> Tensor: + queries = self.to_queries(z[:, : self.max_length, :]) + logits = self.decoder(data=z, queries=queries) + logits = logits.permute(0, 2, 1) # [B, C, Sy] + return logits + + def forward(self, x: Tensor) -> Tensor: + z = self.encode(x) + logits = self.decode(z) + return logits diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py index e69de29..ac2c102 100644 --- a/text_recognizer/networks/perceiver/__init__.py +++ b/text_recognizer/networks/perceiver/__init__.py @@ -0,0 +1 @@ +from text_recognizer.networks.perceiver.perceiver import PerceiverIO diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py index 66aeaa8..0ee51b1 100644 --- a/text_recognizer/networks/perceiver/attention.py +++ b/text_recognizer/networks/perceiver/attention.py @@ -36,11 +36,11 @@ class Attention(nn.Module): q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = einsum("b i d, b j d -> b i j", q, k) * self.scale - if mask is not None: - mask = rearrange(mask, "b ... -> b (...)") - max_neg_val = -torch.finfo(sim.dtype).max - mask = repeat(mask, "b j -> (b h) () j", h=h) - sim.masked_fill_(~mask, max_neg_val) + # if mask is not None: + # mask = rearrange(mask, "b ... -> b (...)") + # max_neg_val = -torch.finfo(sim.dtype).max + # mask = repeat(mask, "b j -> (b h) () j", h=h) + # sim.masked_fill_(~mask, max_neg_val) attn = sim.softmax(dim=-1) out = einsum("b i j, b j d -> b i d", attn, v) diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py index 65ee20c..d4bca0b 100644 --- a/text_recognizer/networks/perceiver/perceiver.py +++ b/text_recognizer/networks/perceiver/perceiver.py @@ -2,9 +2,9 @@ A copy from lucidrains. """ -from itertools import repeat from typing import Optional +from einops import repeat, rearrange import torch from torch import nn, Tensor @@ -44,13 +44,17 @@ class PerceiverIO(nn.Module): self.layers = nn.ModuleList( [ - [ - PreNorm( - latent_dim, - Attention(latent_dim, heads=latent_heads, dim_head=latent_dim), - ), - PreNorm(latent_dim, FeedForward(latent_dim)), - ] + nn.ModuleList( + [ + PreNorm( + latent_dim, + Attention( + latent_dim, heads=latent_heads, dim_head=latent_dim + ), + ), + PreNorm(latent_dim, FeedForward(latent_dim)), + ] + ) for _ in range(depth) ] ) @@ -69,7 +73,7 @@ class PerceiverIO(nn.Module): self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None ) -> Tensor: b = data.shape[0] - x = repeat(self.latents, "nd -> bnd", b=b) + x = repeat(self.latents, "n d -> b n d", b=b) cross_attn, cross_ff = self.cross_attn_block |