summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
commit75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree6521cc4134459e42591b2375f70acd348741474e /text_recognizer/networks
parente5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/conv_transformer.py3
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py15
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py139
-rw-r--r--text_recognizer/networks/transformer/attention.py7
-rw-r--r--text_recognizer/networks/transformer/layers.py6
5 files changed, 75 insertions, 95 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 7371be4..09cc654 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -13,7 +13,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s
+@attr.s(eq=False)
class ConvTransformer(nn.Module):
"""Convolutional encoder and transformer decoder network."""
@@ -121,6 +121,7 @@ class ConvTransformer(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
+ context = context.long()
context_mask = context != self.pad_index
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index a36150a..b8eb53b 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,4 @@
-"""Efficient net."""
+"""Efficientnet backbone."""
from typing import Tuple
import attr
@@ -12,8 +12,10 @@ from .utils import (
)
-@attr.s
+@attr.s(eq=False)
class EfficientNet(nn.Module):
+ """Efficientnet without classification head."""
+
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -47,11 +49,13 @@ class EfficientNet(nn.Module):
@arch.validator
def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ """Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
self.params = self.archs[value]
def _build(self) -> None:
+ """Builds the efficientnet backbone."""
_block_args = block_args()
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
@@ -73,8 +77,9 @@ class EfficientNet(nn.Module):
for args in _block_args:
args.in_channels = round_filters(args.in_channels, self.params)
args.out_channels = round_filters(args.out_channels, self.params)
- args.num_repeats = round_repeats(args.num_repeats, self.params)
- for _ in range(args.num_repeats):
+ num_repeats = round_repeats(args.num_repeats, self.params)
+ del args.num_repeats
+ for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
**args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
@@ -93,6 +98,7 @@ class EfficientNet(nn.Module):
)
def extract_features(self, x: Tensor) -> Tensor:
+ """Extracts the final feature map layer."""
x = self._conv_stem(x)
for i, block in enumerate(self._blocks):
stochastic_dropout_rate = self.stochastic_dropout_rate
@@ -103,4 +109,5 @@ class EfficientNet(nn.Module):
return x
def forward(self, x: Tensor) -> Tensor:
+ """Returns efficientnet image features."""
return self.extract_features(x)
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 3aa63d0..e85df87 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,76 +1,62 @@
"""Mobile inverted residual block."""
-from typing import Any, Optional, Union, Tuple
+from typing import Optional, Sequence, Union, Tuple
+import attr
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from .utils import stochastic_depth
+from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+ """Converts int to tuple."""
+ return (
+ (stride,) * 2 if isinstance(stride, int) else stride
+ )
+
+
+@attr.s(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- bn_momentum: float,
- bn_eps: float,
- se_ratio: float,
- expand_ratio: int,
- *args: Any,
- **kwargs: Any,
- ) -> None:
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.kernel_size = kernel_size
- self.stride = (stride,) * 2 if isinstance(stride, int) else stride
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self.in_channels = in_channels
- self.out_channels = out_channels
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
if self.stride == (2, 2):
- self.pad = [
+ return (
(self.kernel_size - 1) // 2 - 1,
(self.kernel_size - 1) // 2,
- ] * 2
- else:
- self.pad = [(self.kernel_size - 1) // 2] * 4
-
- # Placeholders for layers.
- self._inverted_bottleneck: nn.Sequential = None
- self._depthwise: nn.Sequential = None
- self._squeeze_excite: nn.Sequential = None
- self._pointwise: nn.Sequential = None
-
- self._build(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- expand_ratio=expand_ratio,
- se_ratio=se_ratio,
- )
+ ) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
- def _build(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- expand_ratio: int,
- se_ratio: float,
- ) -> None:
- has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
- inner_channels = in_channels * expand_ratio
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
self._inverted_bottleneck = (
- self._configure_inverted_bottleneck(
- in_channels=in_channels, out_channels=inner_channels,
- )
- if expand_ratio != 1
+ self._configure_inverted_bottleneck(out_channels=inner_channels)
+ if self.expand_ratio != 1
else None
)
@@ -78,31 +64,23 @@ class MBConvBlock(nn.Module):
in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
- kernel_size=kernel_size,
- stride=stride,
)
self._squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels,
- out_channels=inner_channels,
- se_ratio=se_ratio,
+ in_channels=inner_channels, out_channels=inner_channels,
)
if has_se
else None
)
- self._pointwise = self._configure_pointwise(
- in_channels=inner_channels, out_channels=out_channels
- )
+ self._pointwise = self._configure_pointwise(in_channels=inner_channels)
- def _configure_inverted_bottleneck(
- self, in_channels: int, out_channels: int,
- ) -> nn.Sequential:
+ def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
"""Expansion phase."""
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=self.in_channels,
out_channels=out_channels,
kernel_size=1,
bias=False,
@@ -114,19 +92,14 @@ class MBConvBlock(nn.Module):
)
def _configure_depthwise(
- self,
- in_channels: int,
- out_channels: int,
- groups: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
+ self, in_channels: int, out_channels: int, groups: int,
) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
groups=groups,
bias=False,
),
@@ -137,9 +110,9 @@ class MBConvBlock(nn.Module):
)
def _configure_squeeze_excite(
- self, in_channels: int, out_channels: int, se_ratio: float
+ self, in_channels: int, out_channels: int
) -> nn.Sequential:
- num_squeezed_channels = max(1, int(in_channels * se_ratio))
+ num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -154,18 +127,18 @@ class MBConvBlock(nn.Module):
),
)
- def _configure_pointwise(
- self, in_channels: int, out_channels: int
- ) -> nn.Sequential:
+ def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
- out_channels=out_channels,
+ out_channels=self.out_channels,
kernel_size=1,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
)
@@ -186,8 +159,8 @@ class MBConvBlock(nn.Module):
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
- x = F.pad(x, self.pad)
+ x = F.pad(x, self.pad)
x = self._depthwise(x)
if self._squeeze_excite is not None:
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 9202cce..37ce29e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -15,7 +15,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
)
-@attr.s
+@attr.s(eq=False)
class Attention(nn.Module):
"""Standard attention."""
@@ -31,7 +31,6 @@ class Attention(nn.Module):
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
qkv_fn: nn.Sequential = attr.ib(init=False)
- attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
@@ -80,7 +79,7 @@ class Attention(nn.Module):
else k_mask
)
q_mask = rearrange(q_mask, "b i -> b () i ()")
- k_mask = rearrange(k_mask, "b i -> b () () j")
+ k_mask = rearrange(k_mask, "b j -> b () () j")
return q_mask * k_mask
return
@@ -129,7 +128,7 @@ class Attention(nn.Module):
if self.causal:
energy = self._apply_causal_mask(energy, mask, mask_value, device)
- attn = self.attn_fn(energy, dim=-1)
+ attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 66c9c50..ce443e5 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
from text_recognizer.networks.util import load_partial_fn
-@attr.s
+@attr.s(eq=False)
class AttentionLayers(nn.Module):
"""Standard transfomer layer."""
@@ -101,11 +101,11 @@ class AttentionLayers(nn.Module):
return x
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class Encoder(AttentionLayers):
causal: bool = attr.ib(default=False, init=False)
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class Decoder(AttentionLayers):
causal: bool = attr.ib(default=True, init=False)