diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/__init__.py | 4 | ||||
-rw-r--r-- | text_recognizer/networks/base.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 69 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 16 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/norm.py | 8 | ||||
-rw-r--r-- | text_recognizer/networks/util.py | 4 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 1 |
8 files changed, 22 insertions, 100 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 618450f..d9ef58b 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,5 +1 @@ """Network modules""" -# from .encoders import EfficientNet -from .vqvae import VQVAE - -# from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py deleted file mode 100644 index 07b6a32..0000000 --- a/text_recognizer/networks/base.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Base network with required methods.""" -from abc import abstractmethod - -import attr -from torch import nn, Tensor - - -@attr.s -class BaseNetwork(nn.Module): - """Base network.""" - - def __attrs_pre_init__(self) -> None: - super().__init__() - - @abstractmethod - def predict(self, x: Tensor) -> Tensor: - """Return token indices for predictions.""" - ... diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 4acdc36..7371be4 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,13 +1,10 @@ """Vision transformer for character recognition.""" import math -from typing import Tuple, Type +from typing import Tuple import attr -import torch from torch import nn, Tensor -from text_recognizer.data.mappings import AbstractMapping -from text_recognizer.networks.base import BaseNetwork from text_recognizer.networks.encoders.efficientnet import EfficientNet from text_recognizer.networks.transformer.layers import Decoder from text_recognizer.networks.transformer.positional_encodings import ( @@ -16,25 +13,24 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s(auto_attribs=True) -class ConvTransformer(BaseNetwork): +@attr.s +class ConvTransformer(nn.Module): + """Convolutional encoder and transformer decoder network.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + # Parameters and placeholders, input_dims: Tuple[int, int, int] = attr.ib() hidden_dim: int = attr.ib() dropout_rate: float = attr.ib() max_output_len: int = attr.ib() num_classes: int = attr.ib() - start_token: str = attr.ib() - start_index: Tensor = attr.ib(init=False) - end_token: str = attr.ib() - end_index: Tensor = attr.ib(init=False) - pad_token: str = attr.ib() - pad_index: Tensor = attr.ib(init=False) + pad_index: Tensor = attr.ib() # Modules. encoder: EfficientNet = attr.ib() decoder: Decoder = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() latent_encoder: nn.Sequential = attr.ib(init=False) token_embedding: nn.Embedding = attr.ib(init=False) @@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork): def __attrs_post_init__(self) -> None: """Post init configuration.""" - self.start_index = self.mapping.get_index(self.start_token) - self.end_index = self.mapping.get_index(self.end_token) - self.pad_index = self.mapping.get_index(self.pad_token) - # Latent projector for down sampling number of filters and 2d # positional encoding. self.latent_encoder = nn.Sequential( @@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork): z = self.encode(x) logits = self.decode(z, context) return logits - - def predict(self, x: Tensor) -> Tensor: - """Predicts text in image. - - Args: - x (Tensor): Image(s) to extract text from. - - Shapes: - - x: :math: `(B, H, W)` - - output: :math: `(B, S)` - - Returns: - Tensor: A tensor of token indices of the predictions from the model. - """ - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - z = self.encode(x) - - # Create a placeholder matrix for storing outputs from the network - output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - output[:, 0] = self.start_index - - for i in range(1, self.max_output_len): - context = output[:, :i] # (bsz, i) - logits = self.decode(z, context) # (i, bsz, c) - tokens = torch.argmax(logits, dim=-1) # (i, bsz) - output[:, i : i + 1] = tokens[-1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index - ).all(): - break - - # Set all tokens after end token to pad token. - for i in range(1, self.max_output_len): - idx = ( - output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index - ) - output[idx, i] = self.pad_index - - return output diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 2770dc1..9202cce 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -24,9 +24,9 @@ class Attention(nn.Module): dim: int = attr.ib() num_heads: int = attr.ib() + causal: bool = attr.ib(default=False) dim_head: int = attr.ib(default=64) dropout_rate: float = attr.ib(default=0.0) - casual: bool = attr.ib(default=False) scale: float = attr.ib(init=False) dropout: nn.Dropout = attr.ib(init=False) fc: nn.Linear = attr.ib(init=False) diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 9b2f236..66c9c50 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -30,8 +30,7 @@ class AttentionLayers(nn.Module): causal: bool = attr.ib(default=False) cross_attend: bool = attr.ib(default=False) pre_norm: bool = attr.ib(default=True) - rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False) - has_pos_emb: bool = attr.ib(init=False) + rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) attn: partial = attr.ib(init=False) @@ -40,12 +39,11 @@ class AttentionLayers(nn.Module): def __attrs_post_init__(self) -> None: """Post init configuration.""" - self.has_pos_emb = True if self.rotary_emb is not None else False self.layer_types = self._get_layer_types() * self.depth attn = load_partial_fn( self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs ) - norm = load_partial_fn(self.norm_fn, dim=self.dim) + norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) self.layers = self._build_network(attn, norm, ff) @@ -103,13 +101,11 @@ class AttentionLayers(nn.Module): return x +@attr.s(auto_attribs=True) class Encoder(AttentionLayers): - def __init__(self, **kwargs: Any) -> None: - assert "causal" not in kwargs, "Cannot set causality on encoder" - super().__init__(causal=False, **kwargs) + causal: bool = attr.ib(default=False, init=False) +@attr.s(auto_attribs=True) class Decoder(AttentionLayers): - def __init__(self, **kwargs: Any) -> None: - assert "causal" not in kwargs, "Cannot set causality on decoder" - super().__init__(causal=True, **kwargs) + causal: bool = attr.ib(default=True, init=False) diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 8bc3221..4930adf 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -12,9 +12,9 @@ from torch import Tensor class ScaleNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1.0e-5) -> None: + def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None: super().__init__() - self.scale = dim ** -0.5 + self.scale = normalized_shape ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) @@ -24,9 +24,9 @@ class ScaleNorm(nn.Module): class PreNorm(nn.Module): - def __init__(self, dim: int, fn: Type[nn.Module]) -> None: + def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None: super().__init__() - self.norm = nn.LayerNorm(dim) + self.norm = nn.LayerNorm(normalized_shape) self.fn = fn def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index e822c57..c94e8dc 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -24,6 +24,6 @@ def activation_function(activation: str) -> Type[nn.Module]: def load_partial_fn(fn: str, **kwargs: Any) -> partial: - """Loads partial function.""" + """Loads partial function/class.""" module = import_module(".".join(fn.split(".")[:-1])) - return partial(getattr(module, fn.split(".")[0]), **kwargs) + return partial(getattr(module, fn.split(".")[-1]), **kwargs) diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 1f08e5e..5aa929b 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -1,5 +1,4 @@ """The VQ-VAE.""" - from typing import Any, Dict, List, Optional, Tuple from torch import nn |