summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- text_recognizer/networks/__init__.py 4
-rw-r--r-- text_recognizer/networks/base.py 18
-rw-r--r-- text_recognizer/networks/conv_transformer.py 69
-rw-r--r-- text_recognizer/networks/transformer/attention.py 2
-rw-r--r-- text_recognizer/networks/transformer/layers.py 16
-rw-r--r-- text_recognizer/networks/transformer/norm.py 8
-rw-r--r-- text_recognizer/networks/util.py 4
-rw-r--r-- text_recognizer/networks/vqvae/vqvae.py 1
8 files changed, 22 insertions, 100 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 618450f..d9ef58b 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,5 +1 @@
"""Network modules"""
-# from .encoders import EfficientNet
-from .vqvae import VQVAE
-
-# from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
deleted file mode 100644
index 07b6a32..0000000
--- a/text_recognizer/networks/base.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Base network with required methods."""
-from abc import abstractmethod
-
-import attr
-from torch import nn, Tensor
-
-
-@attr.s
-class BaseNetwork(nn.Module):
- """Base network."""
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- @abstractmethod
- def predict(self, x: Tensor) -> Tensor:
- """Return token indices for predictions."""
- ...
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 4acdc36..7371be4 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,13 +1,10 @@
"""Vision transformer for character recognition."""
import math
-from typing import Tuple, Type
+from typing import Tuple
import attr
-import torch
from torch import nn, Tensor
-from text_recognizer.data.mappings import AbstractMapping
-from text_recognizer.networks.base import BaseNetwork
from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
@@ -16,25 +13,24 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s(auto_attribs=True)
-class ConvTransformer(BaseNetwork):
+@attr.s
+class ConvTransformer(nn.Module):
+ """Convolutional encoder and transformer decoder network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
# Parameters and placeholders,
input_dims: Tuple[int, int, int] = attr.ib()
hidden_dim: int = attr.ib()
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
- start_token: str = attr.ib()
- start_index: Tensor = attr.ib(init=False)
- end_token: str = attr.ib()
- end_index: Tensor = attr.ib(init=False)
- pad_token: str = attr.ib()
- pad_index: Tensor = attr.ib(init=False)
+ pad_index: Tensor = attr.ib()
# Modules.
encoder: EfficientNet = attr.ib()
decoder: Decoder = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
latent_encoder: nn.Sequential = attr.ib(init=False)
token_embedding: nn.Embedding = attr.ib(init=False)
@@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = self.mapping.get_index(self.start_token)
- self.end_index = self.mapping.get_index(self.end_token)
- self.pad_index = self.mapping.get_index(self.pad_token)
-
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork):
z = self.encode(x)
logits = self.decode(z, context)
return logits
-
- def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- z = self.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = self.start_index
-
- for i in range(1, self.max_output_len):
- context = output[:, :i] # (bsz, i)
- logits = self.decode(z, context) # (i, bsz, c)
- tokens = torch.argmax(logits, dim=-1) # (i, bsz)
- output[:, i : i + 1] = tokens[-1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for i in range(1, self.max_output_len):
- idx = (
- output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
- )
- output[idx, i] = self.pad_index
-
- return output
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 2770dc1..9202cce 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -24,9 +24,9 @@ class Attention(nn.Module):
dim: int = attr.ib()
num_heads: int = attr.ib()
+ causal: bool = attr.ib(default=False)
dim_head: int = attr.ib(default=64)
dropout_rate: float = attr.ib(default=0.0)
- casual: bool = attr.ib(default=False)
scale: float = attr.ib(init=False)
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 9b2f236..66c9c50 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -30,8 +30,7 @@ class AttentionLayers(nn.Module):
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
- has_pos_emb: bool = attr.ib(init=False)
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
attn: partial = attr.ib(init=False)
@@ -40,12 +39,11 @@ class AttentionLayers(nn.Module):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.has_pos_emb = True if self.rotary_emb is not None else False
self.layer_types = self._get_layer_types() * self.depth
attn = load_partial_fn(
self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
)
- norm = load_partial_fn(self.norm_fn, dim=self.dim)
+ norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
self.layers = self._build_network(attn, norm, ff)
@@ -103,13 +101,11 @@ class AttentionLayers(nn.Module):
return x
+@attr.s(auto_attribs=True)
class Encoder(AttentionLayers):
- def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on encoder"
- super().__init__(causal=False, **kwargs)
+ causal: bool = attr.ib(default=False, init=False)
+@attr.s(auto_attribs=True)
class Decoder(AttentionLayers):
- def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on decoder"
- super().__init__(causal=True, **kwargs)
+ causal: bool = attr.ib(default=True, init=False)
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 8bc3221..4930adf 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -12,9 +12,9 @@ from torch import Tensor
class ScaleNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+ def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
super().__init__()
- self.scale = dim ** -0.5
+ self.scale = normalized_shape ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -24,9 +24,9 @@ class ScaleNorm(nn.Module):
class PreNorm(nn.Module):
- def __init__(self, dim: int, fn: Type[nn.Module]) -> None:
+ def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
super().__init__()
- self.norm = nn.LayerNorm(dim)
+ self.norm = nn.LayerNorm(normalized_shape)
self.fn = fn
def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index e822c57..c94e8dc 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -24,6 +24,6 @@ def activation_function(activation: str) -> Type[nn.Module]:
def load_partial_fn(fn: str, **kwargs: Any) -> partial:
- """Loads partial function."""
+ """Loads partial function/class."""
module = import_module(".".join(fn.split(".")[:-1]))
- return partial(getattr(module, fn.split(".")[0]), **kwargs)
+ return partial(getattr(module, fn.split(".")[-1]), **kwargs)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1f08e5e..5aa929b 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,5 +1,4 @@
"""The VQ-VAE."""
-
from typing import Any, Dict, List, Optional, Tuple
from torch import nn