summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/transformer/attention_layers.py19
-rw-r--r--text_recognizer/networks/transformer/layers.py77
-rw-r--r--text_recognizer/networks/transformer/norm.py11
-rw-r--r--text_recognizer/networks/transformer/nystromer/__init__.py2
-rw-r--r--text_recognizer/networks/transformer/residual.py8
5 files changed, 87 insertions, 30 deletions
diff --git a/text_recognizer/networks/transformer/attention_layers.py b/text_recognizer/networks/transformer/attention_layers.py
deleted file mode 100644
index 721fa27..0000000
--- a/text_recognizer/networks/transformer/attention_layers.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Generates the attention layer architecture."""
-from typing import Type
-
-import torch
-from torch import nn, Tensor
-
-
-class AttentionLayers(nn.Module):
- def __init__(
- self,
- dim: int,
- depth: int,
- num_heads: int,
- norm_layer: Type[nn.Module],
- causal: bool = False,
- cross_attend: bool = False,
- only_cross: bool = False,
- ) -> None:
- pass
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
new file mode 100644
index 0000000..1c951ae
--- /dev/null
+++ b/text_recognizer/networks/transformer/layers.py
@@ -0,0 +1,77 @@
+"""Generates the attention layer architecture."""
+from functools import partial
+from typing import Dict, Optional, Type
+
+from click.types import Tuple
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention
+from .mlp import FeedForward
+from .residual import Residual
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ num_heads: int,
+ ff_kwargs: Dict,
+ attn_kwargs: Dict,
+ attn_fn: Type[nn.Module] = Attention,
+ norm_fn: Type[nn.Module] = nn.LayerNorm,
+ ff_fn: Type[nn.Module] = FeedForward,
+ residual_fn: Type[nn.Module] = Residual,
+ causal: bool = False,
+ cross_attend: bool = False,
+ ) -> None:
+ super().__init__()
+ attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
+ norm_fn = partial(norm_fn, dim=dim)
+ ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
+ layer_types = self._get_layer_types(cross_attend) * depth
+ self.layers = self._build_network(
+ layer_types, causal, attn_fn, norm_fn, ff_fn, residual_fn
+ )
+
+ @staticmethod
+ def _get_layer_types(cross_attend: bool) -> Tuple:
+ """Get layer specification."""
+ if cross_attend:
+ return "a", "c", "f"
+ return "a", "f"
+
+ @staticmethod
+ def _build_network(
+ layer_types: Tuple,
+ causal: bool,
+ attn_fn: partial,
+ norm_fn: partial,
+ ff_fn: partial,
+ residual_fn: Type[nn.Module],
+ ) -> nn.ModuleList:
+ """Configures transformer layers."""
+ layers = nn.ModuleList([])
+ for layer_type in layer_types:
+ if layer_type == "a":
+ layer = attn_fn(causal=causal)
+ elif layer_type == "c":
+ layer = attn_fn()
+ elif layer_type == "f":
+ layer = ff_fn()
+
+ residual_fn = residual_fn()
+
+ layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+ return layers
+
+ def forward(
+ self,
+ x: Tensor,
+ context: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ context_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ pass
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 58c8770..8bc3221 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -11,17 +11,6 @@ from torch import nn
from torch import Tensor
-class Rezero(nn.Module):
- def __init__(self, fn: Callable) -> None:
- super().__init__()
- self.fn = fn
- self.g = nn.Parameter(torch.zeros(1))
-
- def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
- x, *rest = self.fn(x, **kwargs)
- return (x * self.g, *rest)
-
-
class ScaleNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
super().__init__()
diff --git a/text_recognizer/networks/transformer/nystromer/__init__.py b/text_recognizer/networks/transformer/nystromer/__init__.py
index e69de29..ea2c6fc 100644
--- a/text_recognizer/networks/transformer/nystromer/__init__.py
+++ b/text_recognizer/networks/transformer/nystromer/__init__.py
@@ -0,0 +1,2 @@
+"""Nyströmer module."""
+from .nystromer import Nystromer
diff --git a/text_recognizer/networks/transformer/residual.py b/text_recognizer/networks/transformer/residual.py
new file mode 100644
index 0000000..1547df6
--- /dev/null
+++ b/text_recognizer/networks/transformer/residual.py
@@ -0,0 +1,8 @@
+"""Residual function."""
+from torch import nn, Tensor
+
+
+class Residual(nn.Module):
+ def forward(self, x: Tensor, residual: Tensor) -> Tensor:
+ """Applies the residual function."""
+ return x + residual