summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/cnn_tranformer.py30
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py58
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py9
-rw-r--r--text_recognizer/networks/transformer/__init__.py2
-rw-r--r--text_recognizer/networks/transformer/attention.py40
-rw-r--r--text_recognizer/networks/transformer/layers.py91
-rw-r--r--text_recognizer/networks/util.py10
7 files changed, 135 insertions, 105 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index e030cb8..ce7ec43 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -7,6 +7,7 @@ import torch
from torch import nn, Tensor
from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
PositionalEncoding,
@@ -15,7 +16,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
@attr.s
-class CnnTransformer(nn.Module):
+class Reader(nn.Module):
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -27,21 +28,20 @@ class CnnTransformer(nn.Module):
num_classes: int = attr.ib()
padding_idx: int = attr.ib()
start_token: str = attr.ib()
- start_index: int = attr.ib(init=False, default=None)
+ start_index: int = attr.ib(init=False)
end_token: str = attr.ib()
- end_index: int = attr.ib(init=False, default=None)
+ end_index: int = attr.ib(init=False)
pad_token: str = attr.ib()
- pad_index: int = attr.ib(init=False, default=None)
+ pad_index: int = attr.ib(init=False)
# Modules.
- encoder: Type[nn.Module] = attr.ib()
+ encoder: EfficientNet = attr.ib()
decoder: Decoder = attr.ib()
- embedding: nn.Embedding = attr.ib(init=False, default=None)
- latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
- token_embedding: nn.Embedding = attr.ib(init=False, default=None)
- token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
- head: nn.Linear = attr.ib(init=False, default=None)
- mapping: AbstractMapping = attr.ib(init=False, default=None)
+ latent_encoder: nn.Sequential = attr.ib(init=False)
+ token_embedding: nn.Embedding = attr.ib(init=False)
+ token_pos_encoder: PositionalEncoding = attr.ib(init=False)
+ head: nn.Linear = attr.ib(init=False)
+ mapping: Type[AbstractMapping] = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
@@ -187,12 +187,16 @@ class CnnTransformer(nn.Module):
output[:, i : i + 1] = tokens[-1:]
# Early stopping of prediction loop if token is end or padding token.
- if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all():
+ if (
+ output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+ ).all():
break
# Set all tokens after end token to pad token.
for i in range(1, self.max_output_len):
- idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index)
+ idx = (
+ output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+ )
output[idx, i] = self.pad_index
return output
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 6719efb..a36150a 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,7 @@
"""Efficient net."""
+from typing import Tuple
+
+import attr
from torch import nn, Tensor
from .mbconv import MBConvBlock
@@ -9,10 +12,13 @@ from .utils import (
)
+@attr.s
class EfficientNet(nn.Module):
- # TODO: attr
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
archs = {
- # width,depth0res,dropout
+ # width, depth, dropout
"b0": (1.0, 1.0, 0.2),
"b1": (1.0, 1.1, 0.2),
"b2": (1.1, 1.2, 0.3),
@@ -25,30 +31,30 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- def __init__(
- self,
- arch: str,
- out_channels: int = 1280,
- stochastic_dropout_rate: float = 0.2,
- bn_momentum: float = 0.99,
- bn_eps: float = 1.0e-3,
- ) -> None:
- super().__init__()
- assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
- self.arch = self.archs[arch]
- self.out_channels = out_channels
- self.stochastic_dropout_rate = stochastic_dropout_rate
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self._conv_stem: nn.Sequential = None
- self._blocks: nn.ModuleList = None
- self._conv_head: nn.Sequential = None
+ arch: str = attr.ib()
+ params: Tuple[float, float, float] = attr.ib(default=None, init=False)
+ out_channels: int = attr.ib(default=1280)
+ stochastic_dropout_rate: float = attr.ib(default=0.2)
+ bn_momentum: float = attr.ib(default=0.99)
+ bn_eps: float = attr.ib(default=1.0e-3)
+ _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
+ _blocks: nn.ModuleList = attr.ib(default=None, init=False)
+ _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
self._build()
+ @arch.validator
+ def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ if value not in self.archs:
+ raise ValueError(f"{value} not a valid architecure.")
+ self.params = self.archs[value]
+
def _build(self) -> None:
_block_args = block_args()
in_channels = 1 # BW
- out_channels = round_filters(32, self.arch)
+ out_channels = round_filters(32, self.params)
self._conv_stem = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(
@@ -65,9 +71,9 @@ class EfficientNet(nn.Module):
)
self._blocks = nn.ModuleList([])
for args in _block_args:
- args.in_channels = round_filters(args.in_channels, self.arch)
- args.out_channels = round_filters(args.out_channels, self.arch)
- args.num_repeats = round_repeats(args.num_repeats, self.arch)
+ args.in_channels = round_filters(args.in_channels, self.params)
+ args.out_channels = round_filters(args.out_channels, self.params)
+ args.num_repeats = round_repeats(args.num_repeats, self.params)
for _ in range(args.num_repeats):
self._blocks.append(
MBConvBlock(
@@ -77,8 +83,8 @@ class EfficientNet(nn.Module):
args.in_channels = args.out_channels
args.stride = 1
- in_channels = round_filters(320, self.arch)
- out_channels = round_filters(self.out_channels, self.arch)
+ in_channels = round_filters(320, self.params)
+ out_channels = round_filters(self.out_channels, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index e43771a..3aa63d0 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -26,7 +26,7 @@ class MBConvBlock(nn.Module):
) -> None:
super().__init__()
self.kernel_size = kernel_size
- self.stride = (stride, ) * 2 if isinstance(stride, int) else stride
+ self.stride = (stride,) * 2 if isinstance(stride, int) else stride
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.in_channels = in_channels
@@ -68,8 +68,7 @@ class MBConvBlock(nn.Module):
inner_channels = in_channels * expand_ratio
self._inverted_bottleneck = (
self._configure_inverted_bottleneck(
- in_channels=in_channels,
- out_channels=inner_channels,
+ in_channels=in_channels, out_channels=inner_channels,
)
if expand_ratio != 1
else None
@@ -98,9 +97,7 @@ class MBConvBlock(nn.Module):
)
def _configure_inverted_bottleneck(
- self,
- in_channels: int,
- out_channels: int,
+ self, in_channels: int, out_channels: int,
) -> nn.Sequential:
"""Expansion phase."""
return nn.Sequential(
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index a3f3011..51de619 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1 +1,3 @@
"""Transformer modules."""
+from .layers import Decoder, Encoder
+from .transformer import Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 7bafc58..2770dc1 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,6 +1,7 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
+import attr
from einops import rearrange
from einops.layers.torch import Rearrange
import torch
@@ -14,31 +15,38 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
)
+@attr.s
class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int,
- dim_head: int = 64,
- dropout_rate: float = 0.0,
- causal: bool = False,
- ) -> None:
+ """Standard attention."""
+
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.scale = dim ** -0.5
- self.num_heads = num_heads
- self.causal = causal
- inner_dim = dim * dim_head
+
+ dim: int = attr.ib()
+ num_heads: int = attr.ib()
+ dim_head: int = attr.ib(default=64)
+ dropout_rate: float = attr.ib(default=0.0)
+ casual: bool = attr.ib(default=False)
+ scale: float = attr.ib(init=False)
+ dropout: nn.Dropout = attr.ib(init=False)
+ fc: nn.Linear = attr.ib(init=False)
+ qkv_fn: nn.Sequential = attr.ib(init=False)
+ attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.scale = self.dim ** -0.5
+ inner_dim = self.dim * self.dim_head
# Attnetion
self.qkv_fn = nn.Sequential(
- nn.Linear(dim, 3 * inner_dim, bias=False),
+ nn.Linear(self.dim, 3 * inner_dim, bias=False),
Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
)
- self.dropout = nn.Dropout(dropout_rate)
- self.attn_fn = F.softmax
+ self.dropout = nn.Dropout(p=self.dropout_rate)
# Feedforward
- self.fc = nn.Linear(inner_dim, dim)
+ self.fc = nn.Linear(inner_dim, self.dim)
@staticmethod
def _apply_rotary_emb(
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 4daa265..9b2f236 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,67 +1,74 @@
"""Transformer attention layer."""
from functools import partial
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple
+import attr
from torch import nn, Tensor
-from .attention import Attention
-from .mlp import FeedForward
-from .residual import Residual
-from .positional_encodings.rotary_embedding import RotaryEmbedding
+from text_recognizer.networks.transformer.residual import Residual
+from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import (
+ RotaryEmbedding,
+)
+from text_recognizer.networks.util import load_partial_fn
+@attr.s
class AttentionLayers(nn.Module):
- def __init__(
- self,
- dim: int,
- depth: int,
- num_heads: int,
- ff_kwargs: Dict,
- attn_kwargs: Dict,
- attn_fn: Type[nn.Module] = Attention,
- norm_fn: Type[nn.Module] = nn.LayerNorm,
- ff_fn: Type[nn.Module] = FeedForward,
- rotary_emb: Optional[Type[nn.Module]] = None,
- rotary_emb_dim: Optional[int] = None,
- causal: bool = False,
- cross_attend: bool = False,
- pre_norm: bool = True,
- ) -> None:
+ """Standard transfomer layer."""
+
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.dim = dim
- attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
- norm_fn = partial(norm_fn, dim)
- ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
- self.layer_types = self._get_layer_types(cross_attend) * depth
- self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
- rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
- self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
- self.pre_norm = pre_norm
- self.has_pos_emb = True if self.rotary_emb is not None else False
- @staticmethod
- def _get_layer_types(cross_attend: bool) -> Tuple:
+ dim: int = attr.ib()
+ depth: int = attr.ib()
+ num_heads: int = attr.ib()
+ attn_fn: str = attr.ib()
+ attn_kwargs: Dict = attr.ib()
+ norm_fn: str = attr.ib()
+ ff_fn: str = attr.ib()
+ ff_kwargs: Dict = attr.ib()
+ causal: bool = attr.ib(default=False)
+ cross_attend: bool = attr.ib(default=False)
+ pre_norm: bool = attr.ib(default=True)
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
+ has_pos_emb: bool = attr.ib(init=False)
+ layer_types: Tuple[str, ...] = attr.ib(init=False)
+ layers: nn.ModuleList = attr.ib(init=False)
+ attn: partial = attr.ib(init=False)
+ norm: partial = attr.ib(init=False)
+ ff: partial = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.has_pos_emb = True if self.rotary_emb is not None else False
+ self.layer_types = self._get_layer_types() * self.depth
+ attn = load_partial_fn(
+ self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+ )
+ norm = load_partial_fn(self.norm_fn, dim=self.dim)
+ ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
+ self.layers = self._build_network(attn, norm, ff)
+
+ def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
- if cross_attend:
+ if self.cross_attend:
return "a", "c", "f"
return "a", "f"
def _build_network(
- self, causal: bool, attn_fn: partial, norm_fn: partial, ff_fn: partial,
+ self, attn: partial, norm: partial, ff: partial,
) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
- layer = attn_fn(causal=causal)
+ layer = attn(causal=self.causal)
elif layer_type == "c":
- layer = attn_fn()
+ layer = attn()
elif layer_type == "f":
- layer = ff_fn()
-
+ layer = ff()
residual_fn = Residual()
-
- layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+ layers.append(nn.ModuleList([norm(), layer, residual_fn]))
return layers
def forward(
@@ -72,12 +79,10 @@ class AttentionLayers(nn.Module):
context_mask: Optional[Tensor] = None,
) -> Tensor:
rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
-
for i, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = i == len(self.layers) - 1
-
residual = x
if self.pre_norm:
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 85094f1..e822c57 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,5 +1,7 @@
"""Miscellaneous neural network utility functionality."""
-from typing import Type
+from functools import partial
+from importlib import import_module
+from typing import Any, Type
from torch import nn
@@ -19,3 +21,9 @@ def activation_function(activation: str) -> Type[nn.Module]:
]
)
return activation_fns[activation.lower()]
+
+
+def load_partial_fn(fn: str, **kwargs: Any) -> partial:
+ """Loads partial function."""
+ module = import_module(".".join(fn.split(".")[:-1]))
+ return partial(getattr(module, fn.split(".")[0]), **kwargs)