Diffstat (limited to 'text_recognizer/networks')
 text_recognizer/networks/__init__.py                        |   2
 text_recognizer/networks/backbones/__init__.py              |   2
 text_recognizer/networks/backbones/efficientnet.py          | 145
 text_recognizer/networks/cnn_transformer.py                 |  37
 text_recognizer/networks/coat/__init__.py                   |   0
 text_recognizer/networks/coat/factor_attention.py           |   9
 text_recognizer/networks/coat/patch_embedding.py            |  38
 text_recognizer/networks/coat/positional_encodings.py       |  76
 text_recognizer/networks/residual_network.py                |   6
 text_recognizer/networks/transducer/transducer.py           |   7
 text_recognizer/networks/transformer/positional_encoding.py |   6
 text_recognizer/networks/transformer/rotary_embedding.py    |  39
 text_recognizer/networks/vqvae/decoder.py                   |  18
 text_recognizer/networks/vqvae/encoder.py                   |  12
 14 files changed, 227 insertions(+), 170 deletions(-)
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 41fd43f..63b43b2 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,2 +1,4 @@
 """Network modules"""
+from .backbones import EfficientNet
 from .vqvae import VQVAE
+from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/backbones/__init__.py
new file mode 100644
index 0000000..25aed0e
--- /dev/null
+++ b/text_recognizer/networks/backbones/__init__.py
@@ -0,0 +1,2 @@
+"""Vision backbones."""
+from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/backbones/efficientnet.py
new file mode 100644
index 0000000..61dea77
--- /dev/null
+++ b/text_recognizer/networks/backbones/efficientnet.py
@@ -0,0 +1,145 @@
+"""Efficient net b0 implementation."""
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class ConvNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        padding: int,
+        groups: int = 1,
+    ) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                groups=groups,
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=out_channels),
+            nn.SiLU(inplace=True),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.block(x)
+
+
+class SqueezeExcite(nn.Module):
+    def __init__(self, in_channels: int, reduce_dim: int) -> None:
+        super().__init__()
+        self.se = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),  # [C, H, W] -> [C, 1, 1]
+            nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1),
+            nn.SiLU(),
+            nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x * self.se(x)
+
+
+class InvertedResidulaBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        padding: int,
+        expand_ratio: float,
+        reduction: int = 4,
+        survival_prob: float = 0.8,
+    ) -> None:
+        super().__init__()
+        self.survival_prob = survival_prob
+        self.use_residual = in_channels == out_channels and stride == 1
+        hidden_dim = in_channels * expand_ratio
+        self.expand = in_channels != hidden_dim
+        reduce_dim = in_channels // reduction
+
+        if self.expand:
+            self.expand_conv = ConvNorm(
+                in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
+            )
+
+        self.conv = nn.Sequential(
+            ConvNorm(
+                hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim
+            ),
+            SqueezeExcite(hidden_dim, reduce_dim),
+            nn.Conv2d(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=out_channels),
+        )
+
+    def stochastic_depth(self, x: Tensor) -> Tensor:
+        if not self.training:
+            return x
+
+        binary_tensor = (
+            torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
+        )
+        return torch.div(x, self.survival_prob) * binary_tensor
+
+    def forward(self, x: Tensor) -> Tensor:
+        out = self.expand_conv(x) if self.expand else x
+        if self.use_residual:
+            return self.stochastic_depth(self.conv(out)) + x
+        return self.conv(out)
+
+
+class EfficientNet(nn.Module):
+    """Efficient net b0 backbone."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.base_model = [
+            # expand_ratio, channels, repeats, stride, kernel_size
+            [1, 16, 1, 1, 3],
+            [6, 24, 2, 2, 3],
+            [6, 40, 2, 2, 5],
+            [6, 80, 3, 2, 3],
+            [6, 112, 3, 1, 5],
+            [6, 192, 4, 2, 5],
+            [6, 320, 1, 1, 3],
+        ]
+
+        self.backbone = self._build_b0()
+
+    def _build_b0(self) -> nn.Sequential:
+        in_channels = 32
+        layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)]
+
+        for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model:
+            for i in range(repeats):
+                layers.append(
+                    InvertedResidulaBlock(
+                        in_channels,
+                        out_channels,
+                        expand_ratio=expand_ratio,
+                        stride=stride if i == 0 else 1,
+                        kernel_size=kernel_size,
+                        padding=kernel_size // 2,
+                    )
+                )
+                in_channels = out_channels
+        layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.backbone(x)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index e23a15d..d42c29d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000
 class CNNTransformer(nn.Module):
     def __init__(
         self,
-        input_shape: Sequence[int],
-        output_shape: Sequence[int],
+        input_dim: Sequence[int],
+        output_dims: Sequence[int],
         encoder: Union[DictConfig, Dict],
         vocab_size: Optional[int] = None,
         num_decoder_layers: int = 4,
@@ -43,22 +43,29 @@ class CNNTransformer(nn.Module):
         expansion_dim: int = 1024,
         dropout_rate: float = 0.1,
         transformer_activation: str = "glu",
+        *args,
+        **kwargs,
     ) -> None:
+        super().__init__()
         self.vocab_size = (
             NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
         )
+        self.pad_index = 3  # TODO: fix me
         self.hidden_dim = hidden_dim
-        self.max_output_length = output_shape[0]
+        self.max_output_length = output_dims[0]
 
         # Image backbone
         self.encoder = self._configure_encoder(encoder)
+        self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
         self.feature_map_encoding = PositionalEncoding2D(
-            hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+            hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
         )
 
         # Target token embedding
         self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        self.trg_position_encoding = PositionalEncoding(
+            hidden_dim, dropout_rate, max_len=output_dims[0]
+        )
 
         # Transformer decoder
         self.decoder = Decoder(
@@ -86,24 +93,25 @@ class CNNTransformer(nn.Module):
         self.head.weight.data.uniform_(-0.1, 0.1)
 
         nn.init.kaiming_normal_(
-            self.feature_map_encoding.weight.data,
+            self.encoder_proj.weight.data,
             a=0,
             mode="fan_out",
             nonlinearity="relu",
         )
-        if self.feature_map_encoding.bias is not None:
+        if self.encoder_proj.bias is not None:
             _, fan_out = nn.init._calculate_fan_in_and_fan_out(
-                self.feature_map_encoding.weight.data
+                self.encoder_proj.weight.data
             )
             bound = 1 / math.sqrt(fan_out)
-            nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+            nn.init.normal_(self.encoder_proj.bias, -bound, bound)
 
     @staticmethod
     def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
         encoder = OmegaConf.create(encoder)
+        args = encoder.args or {}
         network_module = importlib.import_module("text_recognizer.networks")
         encoder_class = getattr(network_module, encoder.type)
-        return encoder_class(**encoder.args)
+        return encoder_class(**args)
 
     def encode(self, image: Tensor) -> Tensor:
         """Extracts image features with backbone.
@@ -121,6 +129,7 @@ class CNNTransformer(nn.Module):
         """
         # Extract image features.
         image_features = self.encoder(image)
+        image_features = self.encoder_proj(image_features)
 
         # Add 2d encoding to the feature maps.
         image_features = self.feature_map_encoding(image_features)
@@ -133,11 +142,19 @@ class CNNTransformer(nn.Module):
         """Decodes image features with transformer decoder."""
         trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
         trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = rearrange(trg, "b t d -> t b d")
         trg = self.trg_position_encoding(trg)
+        trg = rearrange(trg, "t b d -> b t d")
         out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
         logits = self.head(out)
         return logits
 
+    def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+        image_features = self.encode(image)
+        output = self.decode(image_features, trg)
+        output = rearrange(output, "b t c -> b c t")
+        return output
+
     def predict(self, image: Tensor) -> Tensor:
         """Transcribes text in image(s)."""
         bsz = image.shape[0]
diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/text_recognizer/networks/coat/__init__.py
+++ /dev/null
diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py
deleted file mode 100644
index f91c5ef..0000000
--- a/text_recognizer/networks/coat/factor_attention.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""Factorized attention with convolutional relative positional encodings."""
-from torch import nn
-
-
-class FactorAttention(nn.Module):
-    """Factorized attention with relative positional encodings."""
-    def __init__(self, dim: int, num_heads: int) -> None:
-        pass
-
diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py
deleted file mode 100644
index 3b7b76a..0000000
--- a/text_recognizer/networks/coat/patch_embedding.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Patch embedding for images and feature maps."""
-from typing import Sequence, Tuple
-
-from einops import rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-
-class PatchEmbedding(nn.Module):
-    """Patch embedding of images."""
-
-    def __init__(
-        self,
-        image_shape: Sequence[int],
-        patch_size: int = 16,
-        in_channels: int = 1,
-        embedding_dim: int = 512,
-    ) -> None:
-        if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0:
-            logger.error(
-                f"Image shape {image_shape} not divisable by patch size {patch_size}"
-            )
-
-        self.patch_size = patch_size
-        self.embedding = nn.Conv2d(
-            in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
-        )
-        self.norm = nn.LayerNorm(embedding_dim)
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
-        """Embeds image or feature maps with patch embedding."""
-        _, _, h, w = x.shape
-        h_out, w_out = h // self.patch_size, w // self.patch_size
-        x = self.embedding(x)
-        x = rearrange(x, "b c h w -> b (h w) c")
-        x = self.norm(x)
-        return x, (h_out, w_out)
diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py
deleted file mode 100644
index 925db04..0000000
--- a/text_recognizer/networks/coat/positional_encodings.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""Positional encodings for input sequence to transformer."""
-from typing import Dict, Union, Tuple
-
-from einops import rearrange
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RelativeEncoding(nn.Module):
-    """Relative positional encoding."""
-    def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None:
-        super().__init__()
-        self.windows = {windows: heads} if isinstance(windows, int) else windows
-        self.heads = list(self.windows.values())
-        self.channel_heads = [head * channels for head in self.heads]
-        self.convs = nn.ModuleList([
-            nn.Conv2d(in_channels=head * channels,
-                             out_channels=head * channels,
-                             kernel_shape=window,
-                             padding=window // 2,
-                             dilation=1,
-                             groups=head * channels,
-            ) for window, head in self.windows.items()])
-
-    def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor:
-        """Applies relative positional encoding."""
-        b, heads, hw, c = q.shape
-        h, w = shape
-        if hw != h * w:
-            logger.exception(f"Query width {hw} neq to height x width {h * w}")
-            raise ValueError
-
-        v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w)
-        v = torch.split(v, self.channel_heads, dim=1)
-        v = [conv(x) for conv, x in zip(self.convs, v)]
-        v = torch.cat(v, dim=1)
-        v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads)
-
-        encoding = q * v
-        zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device)
-        encoding = torch.cat((zeros, encoding), dim=2)
-        return encoding
-
-
-class PositionalEncoding(nn.Module):
-    """Convolutional positional encoding."""
-    def __init__(self, dim: int, k: int = 3) -> None:
-        super().__init__()
-        self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim)
-
-    def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor:
-        """Applies convolutional encoding."""
-        _, hw, _ = x.shape
-        h, w = shape
-
-        if hw != h * w:
-            logger.exception(f"Query width {hw} neq to height x width {h * w}")
-            raise ValueError
-
-        # Depthwise convolution.
-        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
-        x = self.encode(x) + x
-        x = rearrange(x, "b c h w -> b (h w) c")
-        return x
-
-
-
-
-
-
-
-
-
-
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index da7553d..c33f419 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
@@ -20,11 +20,7 @@ class Conv2dAuto(nn.Conv2d):
 
 def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
     """3x3 convolution with batch norm."""
-    conv3x3 = partial(
-        Conv2dAuto,
-        kernel_size=3,
-        bias=False,
-    )
+    conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
     return nn.Sequential(
         conv3x3(in_channels, out_channels, *args, **kwargs),
         nn.BatchNorm2d(out_channels),
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index b10f93a..d7e3d08 100644
--- a/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -392,12 +392,7 @@ def load_transducer_loss(
         transitions = gtn.load(str(processed_path / transitions))
 
     preprocessor = Preprocessor(
-        data_dir,
-        num_features,
-        tokens_path,
-        lexicon_path,
-        use_words,
-        prepend_wordsep,
+        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
     )
 
     num_tokens = preprocessor.num_tokens
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 5874e97..c50afc3 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         """Encodes the tensor with a postional embedding."""
-        x = x + self.pe[:, : x.shape[1]]
+        # [T, B, D]
+        if x.shape[2] != self.pe.shape[2]:
+            raise ValueError(f"x shape does not match pe in the 3rd dim.")
+        x = x + self.pe[: x.shape[0]]
         return self.dropout(x)
 
 
@@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module):
         pe = self.make_pe(hidden_dim, max_h, max_w)
         self.register_buffer("pe", pe)
 
+    @staticmethod
     def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
         """Returns 2d postional encoding."""
         pe_h = PositionalEncoding.make_pe(
diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py
new file mode 100644
index 0000000..5e80572
--- /dev/null
+++ b/text_recognizer/networks/transformer/rotary_embedding.py
@@ -0,0 +1,39 @@
+"""Roatary embedding.
+
+Stolen from lucidrains:
+    https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+
+Explanation of roatary:
+    https://blog.eleuther.ai/rotary-embeddings/
+
+"""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+        freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        return emb[None, :, :]
+
+
+def rotate_half(x: Tensor) -> Tensor:
+    x = rearrange(x, "... (j d) -> ... j d", j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+    return q, k
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 93a1e43..32de912 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,12 +44,7 @@ class Decoder(nn.Module):
 
         # Configure encoder.
         self.decoder = self._build_decoder(
-            channels,
-            kernel_sizes,
-            strides,
-            num_residual_layers,
-            activation,
-            dropout,
+            channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
         )
 
     def _build_decompression_block(
@@ -78,9 +73,7 @@ class Decoder(nn.Module):
             )
             if self.upsampling and i < len(self.upsampling):
-                modules.append(
-                    nn.Upsample(size=self.upsampling[i]),
-                )
+                modules.append(nn.Upsample(size=self.upsampling[i]),)
 
             if dropout is not None:
                 modules.append(dropout)
@@ -109,12 +102,7 @@ class Decoder(nn.Module):
     ) -> nn.Sequential:
 
         self.res_block.append(
-            nn.Conv2d(
-                self.embedding_dim,
-                channels[0],
-                kernel_size=1,
-                stride=1,
-            )
+            nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
         )
 
         # Bottleneck module.
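
For reference, a minimal usage sketch of the RotaryEmbedding module added above (not part of this diff; the batch size, sequence length, and head dimension are illustrative assumptions):

    import torch

    from text_recognizer.networks.transformer.rotary_embedding import (
        RotaryEmbedding,
        apply_rotary_pos_emb,
    )

    rotary = RotaryEmbedding(dim=64)  # per-head dimension of q and k
    q = torch.randn(2, 100, 64)  # [batch, seq_len, head_dim]
    k = torch.randn(2, 100, 64)
    freqs = rotary(q, seq_dim=1)  # [1, seq_len, head_dim] rotation angles
    q, k = apply_rotary_pos_emb(q, k, freqs)  # rotated q and k, shapes unchanged
    # q and k now encode relative positions and can be fed to attention as usual.
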
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index b0cceed..65801df 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
 
 class _ResidualBlock(nn.Module):
     def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        dropout: Optional[Type[nn.Module]],
+        self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
    ) -> None:
         super().__init__()
         self.block = [
@@ -138,12 +135,7 @@ class Encoder(nn.Module):
         )
 
         encoder.append(
-            nn.Conv2d(
-                channels[-1],
-                self.embedding_dim,
-                kernel_size=1,
-                stride=1,
-            )
+            nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
         )
 
         return nn.Sequential(*encoder)
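
As a sanity check on the new encoder path: the EfficientNet backbone halves the spatial resolution five times (the stride-2 stem plus four stride-2 stages, x32 in total) and ends in a 1x1 projection to 256 channels, which is why CNNTransformer now adds encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1). A short shape sketch, with the input size as an illustrative assumption:

    import torch

    from text_recognizer.networks import EfficientNet

    backbone = EfficientNet()
    image = torch.randn(1, 1, 576, 640)  # [B, C, H, W] grayscale line image
    features = backbone(image)  # [1, 256, 18, 20], i.e. H / 32 and W / 32
    # encoder_proj then maps 256 -> hidden_dim before the 2D positional
    # encoding is added and the transformer decoder consumes the feature map.
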