Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/conv_transformer.py              | 42
-rw-r--r--  text_recognizer/networks/encoders/efficientnet/mbconv.py  |  9
-rw-r--r--  text_recognizer/networks/transformer/layers.py            | 27
3 files changed, 32 insertions, 46 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 09cc654..f3ba49d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,7 +2,6 @@
import math
from typing import Tuple
-import attr
from torch import nn, Tensor
from text_recognizer.networks.encoders.efficientnet import EfficientNet
@@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s(eq=False)
class ConvTransformer(nn.Module):
"""Convolutional encoder and transformer decoder network."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ dropout_rate: float,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: EfficientNet,
+ decoder: Decoder,
+ ) -> None:
super().__init__()
+ self.input_dims = input_dims
+ self.hidden_dim = hidden_dim
+ self.dropout_rate = dropout_rate
+ self.num_classes = num_classes
+ self.pad_index = pad_index
+ self.encoder = encoder
+ self.decoder = decoder
- # Parameters and placeholders,
- input_dims: Tuple[int, int, int] = attr.ib()
- hidden_dim: int = attr.ib()
- dropout_rate: float = attr.ib()
- max_output_len: int = attr.ib()
- num_classes: int = attr.ib()
- pad_index: Tensor = attr.ib()
-
- # Modules.
- encoder: EfficientNet = attr.ib()
- decoder: Decoder = attr.ib()
-
- latent_encoder: nn.Sequential = attr.ib(init=False)
- token_embedding: nn.Embedding = attr.ib(init=False)
- token_pos_encoder: PositionalEncoding = attr.ib(init=False)
- head: nn.Linear = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -126,7 +121,8 @@ class ConvTransformer(nn.Module):
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
out = self.decoder(x=context, context=z, mask=context_mask)
- logits = self.head(out)
+ logits = self.head(out) # [B, Sy, T]
+ logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits
def forward(self, x: Tensor, context: Tensor) -> Tensor:
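A minimal standalone sketch of why the new permute is useful, assuming T in the comments above denotes the number of token classes: torch.nn.CrossEntropyLoss expects class scores along dim 1, so the [B, Sy, T] output of the linear head has to become [B, T, Sy]. The shapes and pad index below are illustrative assumptions, not values taken from this commit.

    import torch
    from torch import nn

    batch, seq_len, num_classes, pad_index = 4, 10, 58, 3    # assumed values
    logits = torch.randn(batch, seq_len, num_classes)        # head output, [B, Sy, T]
    logits = logits.permute(0, 2, 1)                          # [B, T, Sy], as returned above
    targets = torch.randint(0, num_classes, (batch, seq_len))
    loss = nn.CrossEntropyLoss(ignore_index=pad_index)(logits, targets)
    print(loss.item())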
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index e85df87..7bfd9ba 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -11,9 +11,7 @@ from text_recognizer.networks.encoders.efficientnet.utils import stochastic_dept
def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
"""Converts int to tuple."""
- return (
- (stride,) * 2 if isinstance(stride, int) else stride
- )
+ return (stride,) * 2 if isinstance(stride, int) else stride
@attr.s(eq=False)
@@ -41,10 +39,7 @@ class MBConvBlock(nn.Module):
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return (
- (self.kernel_size - 1) // 2 - 1,
- (self.kernel_size - 1) // 2,
- ) * 2
+ return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None:
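For reference, a runnable sketch of the padding rule in _configure_padding above, extracted as a free function (the 3x3 kernel is just an example value):

    from typing import Tuple

    def configure_padding(kernel_size: int, stride: Tuple[int, int]) -> Tuple[int, int, int, int]:
        # Asymmetric padding when downsampling with stride (2, 2),
        # symmetric "same"-style padding otherwise.
        if stride == (2, 2):
            return ((kernel_size - 1) // 2 - 1, (kernel_size - 1) // 2) * 2
        return ((kernel_size - 1) // 2,) * 4

    print(configure_padding(3, (2, 2)))  # (0, 1, 0, 1)
    print(configure_padding(3, (1, 1)))  # (1, 1, 1, 1)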
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index ce443e5..70a0ac7 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,5 +1,4 @@
"""Transformer attention layer."""
-from functools import partial
from typing import Any, Dict, Optional, Tuple
import attr
@@ -27,25 +26,17 @@ class AttentionLayers(nn.Module):
norm_fn: str = attr.ib()
ff_fn: str = attr.ib()
ff_kwargs: Dict = attr.ib()
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib()
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
- attn: partial = attr.ib(init=False)
- norm: partial = attr.ib(init=False)
- ff: partial = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self.layer_types = self._get_layer_types() * self.depth
- attn = load_partial_fn(
- self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
- )
- norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
- ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
- self.layers = self._build_network(attn, norm, ff)
+ self.layers = self._build_network()
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
@@ -53,10 +44,13 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- def _build_network(
- self, attn: partial, norm: partial, ff: partial,
- ) -> nn.ModuleList:
+ def _build_network(self) -> nn.ModuleList:
"""Configures transformer network."""
+ attn = load_partial_fn(
+ self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+ )
+ norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
+ ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
@@ -106,6 +100,7 @@ class Encoder(AttentionLayers):
causal: bool = attr.ib(default=False, init=False)
-@attr.s(auto_attribs=True, eq=False)
class Decoder(AttentionLayers):
- causal: bool = attr.ib(default=True, init=False)
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on decoder"
+ super().__init__(causal=True, **kwargs)
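A toy reproduction of the Decoder pattern above, using simplified stand-in classes (the real AttentionLayers takes many more attrs-managed fields such as dim, depth and num_heads):

    from typing import Any

    class AttentionLayersSketch:
        def __init__(self, causal: bool = False, **kwargs: Any) -> None:
            self.causal = causal

    class DecoderSketch(AttentionLayersSketch):
        def __init__(self, **kwargs: Any) -> None:
            # Causality is always enforced; passing causal explicitly is an error.
            assert "causal" not in kwargs, "Cannot set causality on decoder"
            super().__init__(causal=True, **kwargs)

    print(DecoderSketch().causal)   # True
    # DecoderSketch(causal=False)   # would raise AssertionError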