summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- text_recognizer/networks/conv_transformer.py 69
1 files changed, 9 insertions, 60 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 4acdc36..7371be4 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,13 +1,10 @@
"""Vision transformer for character recognition."""
import math
-from typing import Tuple, Type
+from typing import Tuple
import attr
-import torch
from torch import nn, Tensor
-from text_recognizer.data.mappings import AbstractMapping
-from text_recognizer.networks.base import BaseNetwork
from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
@@ -16,25 +13,24 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s(auto_attribs=True)
-class ConvTransformer(BaseNetwork):
+@attr.s
+class ConvTransformer(nn.Module):
+ """Convolutional encoder and transformer decoder network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
# Parameters and placeholders,
input_dims: Tuple[int, int, int] = attr.ib()
hidden_dim: int = attr.ib()
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
- start_token: str = attr.ib()
- start_index: Tensor = attr.ib(init=False)
- end_token: str = attr.ib()
- end_index: Tensor = attr.ib(init=False)
- pad_token: str = attr.ib()
- pad_index: Tensor = attr.ib(init=False)
+ pad_index: Tensor = attr.ib()
# Modules.
encoder: EfficientNet = attr.ib()
decoder: Decoder = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
latent_encoder: nn.Sequential = attr.ib(init=False)
token_embedding: nn.Embedding = attr.ib(init=False)
@@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = self.mapping.get_index(self.start_token)
- self.end_index = self.mapping.get_index(self.end_token)
- self.pad_index = self.mapping.get_index(self.pad_token)
-
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork):
z = self.encode(x)
logits = self.decode(z, context)
return logits
-
- def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- z = self.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
-
- for i in range(1, self.max_output_len):
- context = output[:, :i] # (bsz, i)
- logits = self.decode(z, context) # (i, bsz, c)
- tokens = torch.argmax(logits, dim=-1) # (i, bsz)
- output[:, i : i + 1] = tokens[-1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (
- output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
- )
- output[idx, i] = self.pad_index
-
- return output