From bd4bd443f339e95007bfdabf3e060db720f4d4b9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 3 Aug 2021 18:18:48 +0200 Subject: Training working, multiple bug fixes --- text_recognizer/networks/conv_transformer.py | 42 ++++++++++------------ .../networks/encoders/efficientnet/mbconv.py | 9 ++--- text_recognizer/networks/transformer/layers.py | 27 ++++++-------- 3 files changed, 32 insertions(+), 46 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 09cc654..f3ba49d 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -2,7 +2,6 @@ import math from typing import Tuple -import attr from torch import nn, Tensor from text_recognizer.networks.encoders.efficientnet import EfficientNet @@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s(eq=False) class ConvTransformer(nn.Module): """Convolutional encoder and transformer decoder network.""" - def __attrs_pre_init__(self) -> None: + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + dropout_rate: float, + num_classes: int, + pad_index: Tensor, + encoder: EfficientNet, + decoder: Decoder, + ) -> None: super().__init__() + self.input_dims = input_dims + self.hidden_dim = hidden_dim + self.dropout_rate = dropout_rate + self.num_classes = num_classes + self.pad_index = pad_index + self.encoder = encoder + self.decoder = decoder - # Parameters and placeholders, - input_dims: Tuple[int, int, int] = attr.ib() - hidden_dim: int = attr.ib() - dropout_rate: float = attr.ib() - max_output_len: int = attr.ib() - num_classes: int = attr.ib() - pad_index: Tensor = attr.ib() - - # Modules. - encoder: EfficientNet = attr.ib() - decoder: Decoder = attr.ib() - - latent_encoder: nn.Sequential = attr.ib(init=False) - token_embedding: nn.Embedding = attr.ib(init=False) - token_pos_encoder: PositionalEncoding = attr.ib(init=False) - head: nn.Linear = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" # Latent projector for down sampling number of filters and 2d # positional encoding. self.latent_encoder = nn.Sequential( @@ -126,7 +121,8 @@ class ConvTransformer(nn.Module): context = self.token_embedding(context) * math.sqrt(self.hidden_dim) context = self.token_pos_encoder(context) out = self.decoder(x=context, context=z, mask=context_mask) - logits = self.head(out) + logits = self.head(out) # [B, Sy, T] + logits = logits.permute(0, 2, 1) # [B, T, Sy] return logits def forward(self, x: Tensor, context: Tensor) -> Tensor: diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index e85df87..7bfd9ba 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -11,9 +11,7 @@ from text_recognizer.networks.encoders.efficientnet.utils import stochastic_dept def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: """Converts int to tuple.""" - return ( - (stride,) * 2 if isinstance(stride, int) else stride - ) + return (stride,) * 2 if isinstance(stride, int) else stride @attr.s(eq=False) @@ -41,10 +39,7 @@ class MBConvBlock(nn.Module): def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): - return ( - (self.kernel_size - 1) // 2 - 1, - (self.kernel_size - 1) // 2, - ) * 2 + return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 return ((self.kernel_size - 1) // 2,) * 4 def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index ce443e5..70a0ac7 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,5 +1,4 @@ """Transformer attention layer.""" -from functools import partial from typing import Any, Dict, Optional, Tuple import attr @@ -27,25 +26,17 @@ class AttentionLayers(nn.Module): norm_fn: str = attr.ib() ff_fn: str = attr.ib() ff_kwargs: Dict = attr.ib() + rotary_emb: Optional[RotaryEmbedding] = attr.ib() causal: bool = attr.ib(default=False) cross_attend: bool = attr.ib(default=False) pre_norm: bool = attr.ib(default=True) - rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) - attn: partial = attr.ib(init=False) - norm: partial = attr.ib(init=False) - ff: partial = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" self.layer_types = self._get_layer_types() * self.depth - attn = load_partial_fn( - self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs - ) - norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) - ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) - self.layers = self._build_network(attn, norm, ff) + self.layers = self._build_network() def _get_layer_types(self) -> Tuple: """Get layer specification.""" @@ -53,10 +44,13 @@ class AttentionLayers(nn.Module): return "a", "c", "f" return "a", "f" - def _build_network( - self, attn: partial, norm: partial, ff: partial, - ) -> nn.ModuleList: + def _build_network(self) -> nn.ModuleList: """Configures transformer network.""" + attn = load_partial_fn( + self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs + ) + norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) + ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) layers = nn.ModuleList([]) for layer_type in self.layer_types: if layer_type == "a": @@ -106,6 +100,7 @@ class Encoder(AttentionLayers): causal: bool = attr.ib(default=False, init=False) -@attr.s(auto_attribs=True, eq=False) class Decoder(AttentionLayers): - causal: bool = attr.ib(default=True, init=False) + def __init__(self, **kwargs: Any) -> None: + assert "causal" not in kwargs, "Cannot set causality on decoder" + super().__init__(causal=True, **kwargs) -- cgit v1.2.3-70-g09d2