author    | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200
commit    | 7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree      | 8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/networks/conv_transformer.py
parent    | 92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 69
1 file changed, 9 insertions(+), 60 deletions(-)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 4acdc36..7371be4 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,13 +1,10 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Tuple, Type
+from typing import Tuple
 
 import attr
-import torch
 from torch import nn, Tensor
 
-from text_recognizer.data.mappings import AbstractMapping
-from text_recognizer.networks.base import BaseNetwork
 from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
@@ -16,25 +13,24 @@ from text_recognizer.networks.transformer.positional_encodings import (
 )
 
 
-@attr.s(auto_attribs=True)
-class ConvTransformer(BaseNetwork):
+@attr.s
+class ConvTransformer(nn.Module):
+    """Convolutional encoder and transformer decoder network."""
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
     # Parameters and placeholders,
     input_dims: Tuple[int, int, int] = attr.ib()
     hidden_dim: int = attr.ib()
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
-    start_token: str = attr.ib()
-    start_index: Tensor = attr.ib(init=False)
-    end_token: str = attr.ib()
-    end_index: Tensor = attr.ib(init=False)
-    pad_token: str = attr.ib()
-    pad_index: Tensor = attr.ib(init=False)
+    pad_index: Tensor = attr.ib()
 
     # Modules.
     encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
-    mapping: Type[AbstractMapping] = attr.ib()
 
     latent_encoder: nn.Sequential = attr.ib(init=False)
     token_embedding: nn.Embedding = attr.ib(init=False)
@@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork):
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = self.mapping.get_index(self.start_token)
-        self.end_index = self.mapping.get_index(self.end_token)
-        self.pad_index = self.mapping.get_index(self.pad_token)
-
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork):
         z = self.encode(x)
         logits = self.decode(z, context)
         return logits
-
-    def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image.
-
-        Args:
-            x (Tensor): Image(s) to extract text from.
-
-        Shapes:
-            - x: :math: `(B, H, W)`
-            - output: :math: `(B, S)`
-
-        Returns:
-            Tensor: A tensor of token indices of the predictions from the model.
-        """
-        bsz = x.shape[0]
-
-        # Encode image(s) to latent vectors.
-        z = self.encode(x)
-
-        # Create a placeholder matrix for storing outputs from the network
-        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        output[:, 0] = self.start_index
-
-        for i in range(1, self.max_output_len):
-            context = output[:, :i]  # (bsz, i)
-            logits = self.decode(z, context)  # (i, bsz, c)
-            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
-            output[:, i : i + 1] = tokens[-1:]
-
-            # Early stopping of prediction loop if token is end or padding token.
-            if (
-                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
-            ).all():
-                break
-
-        # Set all tokens after end token to pad token.
-        for i in range(1, self.max_output_len):
-            idx = (
-                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
-            )
-            output[idx, i] = self.pad_index
-
-        return output
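Why the fix works: when an `attr.s`-decorated class subclasses `nn.Module`, `nn.Module.__init__` has to run before the attrs-generated `__init__` assigns any fields, because `nn.Module.__setattr__` expects the `_parameters`/`_buffers`/`_modules` registries to exist before a submodule can be attached. The `__attrs_pre_init__` hook (available in attrs 20.1.0 and later) runs at exactly that point. Below is a minimal sketch of the pattern, not taken from this repo; `TinyNet`, `hidden_dim`, and `head` are hypothetical names used only for illustration.

```python
# Minimal sketch of the attrs/nn.Module initialization pattern this
# commit adopts. TinyNet, hidden_dim, and head are hypothetical names.
# Assumes attrs >= 20.1.0, which introduced __attrs_pre_init__.
import attr
import torch
from torch import nn, Tensor


@attr.s
class TinyNet(nn.Module):
    """Toy attrs-based module."""

    hidden_dim: int = attr.ib()
    head: nn.Linear = attr.ib(init=False)

    def __attrs_pre_init__(self) -> None:
        # Runs before the attrs-generated __init__ assigns any fields,
        # so nn.Module's registries (_parameters, _buffers, _modules)
        # exist by the time a submodule is attached below.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Safe now: nn.Module.__setattr__ registers the submodule.
        self.head = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        return self.head(x)


net = TinyNet(hidden_dim=8)
out = net(torch.randn(2, 8))         # forward pass works
print(sorted(net.state_dict()))      # ['head.bias', 'head.weight']
```

The diff also turns `pad_index` into a plain constructor argument instead of deriving it from the removed `mapping` attribute at init time, which is what lets the attrs field list drop the `AbstractMapping` dependency entirely.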