diff options
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
| -rw-r--r-- | text_recognizer/networks/conv_transformer.py | 69 | 
1 files changed, 9 insertions, 60 deletions
| diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 4acdc36..7371be4 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,13 +1,10 @@  """Vision transformer for character recognition."""  import math -from typing import Tuple, Type +from typing import Tuple  import attr -import torch  from torch import nn, Tensor -from text_recognizer.data.mappings import AbstractMapping -from text_recognizer.networks.base import BaseNetwork  from text_recognizer.networks.encoders.efficientnet import EfficientNet  from text_recognizer.networks.transformer.layers import Decoder  from text_recognizer.networks.transformer.positional_encodings import ( @@ -16,25 +13,24 @@ from text_recognizer.networks.transformer.positional_encodings import (  ) -@attr.s(auto_attribs=True) -class ConvTransformer(BaseNetwork): +@attr.s +class ConvTransformer(nn.Module): +    """Convolutional encoder and transformer decoder network.""" + +    def __attrs_pre_init__(self) -> None: +        super().__init__() +      # Parameters and placeholders,      input_dims: Tuple[int, int, int] = attr.ib()      hidden_dim: int = attr.ib()      dropout_rate: float = attr.ib()      max_output_len: int = attr.ib()      num_classes: int = attr.ib() -    start_token: str = attr.ib() -    start_index: Tensor = attr.ib(init=False) -    end_token: str = attr.ib() -    end_index: Tensor = attr.ib(init=False) -    pad_token: str = attr.ib() -    pad_index: Tensor = attr.ib(init=False) +    pad_index: Tensor = attr.ib()      # Modules.      encoder: EfficientNet = attr.ib()      decoder: Decoder = attr.ib() -    mapping: Type[AbstractMapping] = attr.ib()      latent_encoder: nn.Sequential = attr.ib(init=False)      token_embedding: nn.Embedding = attr.ib(init=False) @@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork):      def __attrs_post_init__(self) -> None:          """Post init configuration.""" -        self.start_index = self.mapping.get_index(self.start_token) -        self.end_index = self.mapping.get_index(self.end_token) -        self.pad_index = self.mapping.get_index(self.pad_token) -          # Latent projector for down sampling number of filters and 2d          # positional encoding.          self.latent_encoder = nn.Sequential( @@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork):          z = self.encode(x)          logits = self.decode(z, context)          return logits - -    def predict(self, x: Tensor) -> Tensor: -        """Predicts text in image. -         -        Args: -            x (Tensor): Image(s) to extract text from. - -        Shapes: -            - x: :math: `(B, H, W)` -            - output: :math: `(B, S)` - -        Returns: -            Tensor: A tensor of token indices of the predictions from the model. -        """ -        bsz = x.shape[0] - -        # Encode image(s) to latent vectors. -        z = self.encode(x) - -        # Create a placeholder matrix for storing outputs from the network -        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) -        output[:, 0] = self.start_index - -        for i in range(1, self.max_output_len): -            context = output[:, :i]  # (bsz, i) -            logits = self.decode(z, context)  # (i, bsz, c) -            tokens = torch.argmax(logits, dim=-1)  # (i, bsz) -            output[:, i : i + 1] = tokens[-1:] - -            # Early stopping of prediction loop if token is end or padding token. -            if ( -                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index -            ).all(): -                break - -        # Set all tokens after end token to pad token. -        for i in range(1, self.max_output_len): -            idx = ( -                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index -            ) -            output[idx, i] = self.pad_index - -        return output |