summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/models/perceiver.py76
-rw-r--r--text_recognizer/networks/conv_perceiver.py62
-rw-r--r--text_recognizer/networks/perceiver/__init__.py1
-rw-r--r--text_recognizer/networks/perceiver/attention.py10
-rw-r--r--text_recognizer/networks/perceiver/perceiver.py22
-rw-r--r--training/conf/model/lit_perceiver.yaml5
-rw-r--r--training/conf/network/conv_perceiver.yaml30
7 files changed, 192 insertions, 14 deletions
diff --git a/text_recognizer/models/perceiver.py b/text_recognizer/models/perceiver.py
new file mode 100644
index 0000000..c482235
--- /dev/null
+++ b/text_recognizer/models/perceiver.py
@@ -0,0 +1,76 @@
+"""Lightning model for base Perceiver."""
+from typing import Optional, Tuple, Type
+
+from omegaconf import DictConfig
+import torch
+from torch import nn, Tensor
+
+from text_recognizer.data.mappings import EmnistMapping
+from text_recognizer.models.base import LitBase
+from text_recognizer.models.metrics import CharacterErrorRate
+
+
class LitPerceiver(LitBase):
    """PyTorch Lightning wrapper for a Perceiver text-recognition network.

    Resolves the special-token indices from the character mapping once at
    construction and tracks character error rate (CER) on validation and
    test data, ignoring start/end/pad positions.
    """

    def __init__(
        self,
        network: Type[nn.Module],
        loss_fn: Type[nn.Module],
        optimizer_config: DictConfig,
        lr_scheduler_config: Optional[DictConfig],
        mapping: EmnistMapping,
        max_output_len: int = 682,
        start_token: str = "<s>",
        end_token: str = "<e>",
        pad_token: str = "<p>",
    ) -> None:
        super().__init__(
            network, loss_fn, optimizer_config, lr_scheduler_config, mapping
        )
        self.max_output_len = max_output_len
        self.start_token = start_token
        self.end_token = end_token
        self.pad_token = pad_token
        # Resolve token indices once; CER metrics skip these positions.
        self.start_index = int(self.mapping.get_index(self.start_token))
        self.end_index = int(self.mapping.get_index(self.end_token))
        self.pad_index = int(self.mapping.get_index(self.pad_token))
        self.ignore_indices = {self.start_index, self.end_index, self.pad_index}
        self.val_cer = CharacterErrorRate(self.ignore_indices)
        self.test_cer = CharacterErrorRate(self.ignore_indices)

    def forward(self, data: Tensor) -> Tensor:
        """Return predicted token indices for a batch of images."""
        return self.predict(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute and log the training loss for one batch."""
        data, targets = batch
        logits = self.network(data)
        loss = self.loss_fn(logits, targets)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Log accuracy and CER on one validation batch."""
        data, targets = batch
        preds = self.predict(data)
        # val_acc is provided by LitBase.
        self.val_acc(preds, targets)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
        self.val_cer(preds, targets)
        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Log accuracy and CER on one test batch."""
        data, targets = batch

        # Compute the text prediction.
        pred = self(data)
        self.test_cer(pred, targets)
        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
        self.test_acc(pred, targets)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)

    @torch.no_grad()
    def predict(self, x: Tensor) -> Tensor:
        """Greedy prediction: argmax over the class dimension of the logits."""
        # Network output is [B, C, S]; argmax(dim=1) yields indices [B, S].
        return self.network(x).argmax(dim=1)
diff --git a/text_recognizer/networks/conv_perceiver.py b/text_recognizer/networks/conv_perceiver.py
new file mode 100644
index 0000000..cda5e0a
--- /dev/null
+++ b/text_recognizer/networks/conv_perceiver.py
@@ -0,0 +1,62 @@
+"""Perceiver network module."""
+from typing import Optional, Tuple, Type
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.perceiver.perceiver import PerceiverIO
+from text_recognizer.networks.transformer.embeddings.axial import (
+ AxialPositionalEmbedding,
+)
+
+
class ConvPerceiver(nn.Module):
    """Convolutional encoder feeding a PerceiverIO decoder.

    The encoder produces a feature map that is projected to ``hidden_dim``
    channels with a 1x1 convolution, augmented with a 2D positional
    embedding, and flattened into a sequence for the PerceiverIO decoder.

    Args:
        input_dims: Expected input dimensions (stored for reference only).
        hidden_dim: Channel width of the flattened encoder sequence.
        queries_dim: Dimensionality of the decoder queries.
        num_classes: Number of output character classes.
        pad_index: Index of the padding token (stored for reference only).
        encoder: Convolutional backbone instance; must expose ``out_channels``.
        decoder: PerceiverIO module mapping ``(data, queries)`` to logits.
        max_length: Maximum output sequence length (number of queries).
        pixel_embedding: 2D positional embedding applied to the feature map.
    """

    def __init__(
        self,
        input_dims: Tuple[int, int, int],
        hidden_dim: int,
        queries_dim: int,
        num_classes: int,
        pad_index: int,
        encoder: nn.Module,
        decoder: "PerceiverIO",
        max_length: int,
        pixel_embedding: "AxialPositionalEmbedding",
    ) -> None:
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.pad_index = pad_index
        self.max_length = max_length
        self.encoder = encoder
        self.decoder = decoder
        self.pixel_embedding = pixel_embedding
        self.to_queries = nn.Linear(self.hidden_dim, queries_dim)
        # 1x1 conv projects the encoder's channels down/up to hidden_dim.
        self.conv = nn.Conv2d(
            in_channels=self.encoder.out_channels,
            out_channels=self.hidden_dim,
            kernel_size=1,
        )

    def encode(self, x: Tensor) -> Tensor:
        """Encode an image batch into a sequence of shape [B, Ho*Wo, E]."""
        z = self.encoder(x)
        z = self.conv(z)
        z = self.pixel_embedding(z)
        z = z.flatten(start_dim=2)

        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
        z = z.permute(0, 2, 1)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Decode the encoded sequence into class logits [B, C, Sy]."""
        # The first max_length positions of the sequence become the queries.
        queries = self.to_queries(z[:, : self.max_length, :])
        logits = self.decoder(data=z, queries=queries)
        logits = logits.permute(0, 2, 1)  # [B, C, Sy]
        return logits

    def forward(self, x: Tensor) -> Tensor:
        """Run the full pipeline: image batch -> class logits."""
        z = self.encode(x)
        logits = self.decode(z)
        return logits
diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
index e69de29..ac2c102 100644
--- a/text_recognizer/networks/perceiver/__init__.py
+++ b/text_recognizer/networks/perceiver/__init__.py
@@ -0,0 +1 @@
+from text_recognizer.networks.perceiver.perceiver import PerceiverIO
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
index 66aeaa8..0ee51b1 100644
--- a/text_recognizer/networks/perceiver/attention.py
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -36,11 +36,11 @@ class Attention(nn.Module):
q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
- if mask is not None:
- mask = rearrange(mask, "b ... -> b (...)")
- max_neg_val = -torch.finfo(sim.dtype).max
- mask = repeat(mask, "b j -> (b h) () j", h=h)
- sim.masked_fill_(~mask, max_neg_val)
+ # if mask is not None:
+ # mask = rearrange(mask, "b ... -> b (...)")
+ # max_neg_val = -torch.finfo(sim.dtype).max
+ # mask = repeat(mask, "b j -> (b h) () j", h=h)
+ # sim.masked_fill_(~mask, max_neg_val)
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
index 65ee20c..d4bca0b 100644
--- a/text_recognizer/networks/perceiver/perceiver.py
+++ b/text_recognizer/networks/perceiver/perceiver.py
@@ -2,9 +2,9 @@
A copy from lucidrains.
"""
-from itertools import repeat
from typing import Optional
+from einops import repeat, rearrange
import torch
from torch import nn, Tensor
@@ -44,13 +44,17 @@ class PerceiverIO(nn.Module):
self.layers = nn.ModuleList(
[
- [
- PreNorm(
- latent_dim,
- Attention(latent_dim, heads=latent_heads, dim_head=latent_dim),
- ),
- PreNorm(latent_dim, FeedForward(latent_dim)),
- ]
+ nn.ModuleList(
+ [
+ PreNorm(
+ latent_dim,
+ Attention(
+ latent_dim, heads=latent_heads, dim_head=latent_dim
+ ),
+ ),
+ PreNorm(latent_dim, FeedForward(latent_dim)),
+ ]
+ )
for _ in range(depth)
]
)
@@ -69,7 +73,7 @@ class PerceiverIO(nn.Module):
self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None
) -> Tensor:
b = data.shape[0]
- x = repeat(self.latents, "nd -> bnd", b=b)
+ x = repeat(self.latents, "n d -> b n d", b=b)
cross_attn, cross_ff = self.cross_attn_block
diff --git a/training/conf/model/lit_perceiver.yaml b/training/conf/model/lit_perceiver.yaml
new file mode 100644
index 0000000..6d1ec82
--- /dev/null
+++ b/training/conf/model/lit_perceiver.yaml
@@ -0,0 +1,5 @@
# Hydra config for the Lightning Perceiver model wrapper.
_target_: text_recognizer.models.LitPerceiver
# Maximum number of tokens in a predicted output sequence.
max_output_len: 682
# Special tokens; their indices are resolved via the character mapping
# and ignored by the CER metric.
start_token: <s>
end_token: <e>
pad_token: <p>
diff --git a/training/conf/network/conv_perceiver.yaml b/training/conf/network/conv_perceiver.yaml
new file mode 100644
index 0000000..e6906fa
--- /dev/null
+++ b/training/conf/network/conv_perceiver.yaml
@@ -0,0 +1,30 @@
# Hydra config: EfficientNet encoder + PerceiverIO decoder network.
_target_: text_recognizer.networks.ConvPerceiver
# NOTE(review): four entries here, but ConvPerceiver annotates input_dims as a
# 3-tuple — confirm the expected shape ([B, C, H, W] vs [C, H, W]).
input_dims: [1, 1, 576, 640]
hidden_dim: &hidden_dim 144
num_classes: &num_classes 58
queries_dim: &queries_dim 16
# Maximum output sequence length (number of decoder queries).
max_length: 89
pad_index: 3
encoder:
  _target_: text_recognizer.networks.EfficientNet
  arch: b0
  stochastic_dropout_rate: 0.2
  bn_momentum: 0.99
  bn_eps: 1.0e-3
  depth: 5
  # Aliased to hidden_dim, so the 1x1 projection is 144 -> 144 channels.
  out_channels: *hidden_dim
decoder:
  _target_: text_recognizer.networks.perceiver.PerceiverIO
  dim: *hidden_dim
  cross_heads: 1
  cross_head_dim: 64
  num_latents: 256
  latent_dim: 512
  latent_heads: 8
  depth: 6
  queries_dim: *queries_dim
  logits_dim: *num_classes
pixel_embedding:
  _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
  dim: *hidden_dim
  # Axial (H, W) shape of the embedding grid — presumably matches the
  # encoder's downsampled feature-map size; verify against the b0 backbone.
  shape: [3, 64]