summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/transformer/attention.py1
-rw-r--r--text_recognizer/networks/transformer/layers.py13
-rw-r--r--text_recognizer/networks/transformer/transformer.py4
3 files changed, 8 insertions, 10 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index eabeadf..a3b53f0 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -23,6 +23,7 @@ class Attention(nn.Module):
dropout_rate: float = 0.0,
causal: bool = False,
) -> None:
+ super().__init__()
self.scale = dim ** -0.5
self.num_heads = num_heads
self.causal = causal
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 4063425..b2c703f 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -4,13 +4,12 @@ from typing import Any, Dict, Optional, Type
from click.types import Tuple
-import torch
from torch import nn, Tensor
from .attention import Attention
from .mlp import FeedForward
from .residual import Residual
-from .rotary_embedding import RotaryEmbedding
+from .positional_encodings.rotary_embedding import RotaryEmbedding
class AttentionLayers(nn.Module):
@@ -24,7 +23,6 @@ class AttentionLayers(nn.Module):
attn_fn: Type[nn.Module] = Attention,
norm_fn: Type[nn.Module] = nn.LayerNorm,
ff_fn: Type[nn.Module] = FeedForward,
- residual_fn: Type[nn.Module] = Residual,
rotary_emb: Optional[Type[nn.Module]] = None,
rotary_emb_dim: Optional[int] = None,
causal: bool = False,
@@ -33,10 +31,10 @@ class AttentionLayers(nn.Module):
) -> None:
super().__init__()
attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
- norm_fn = partial(norm_fn, dim=dim)
+ norm_fn = partial(norm_fn, dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
self.layer_types = self._get_layer_types(cross_attend) * depth
- self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn)
+ self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
self.pre_norm = pre_norm
@@ -55,7 +53,6 @@ class AttentionLayers(nn.Module):
attn_fn: partial,
norm_fn: partial,
ff_fn: partial,
- residual_fn: Type[nn.Module],
) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
@@ -67,9 +64,9 @@ class AttentionLayers(nn.Module):
elif layer_type == "f":
layer = ff_fn()
- residual_fn = residual_fn()
+ residual_fn = Residual()
- layers.append(nn.modulelist([norm_fn(), layer, residual_fn]))
+ layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
return layers
def forward(
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index 36f86ac..60ab1ce 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -1,5 +1,5 @@
"""Transformer wrapper."""
-from typing import Optional, Type
+from typing import Any, Optional, Type
from torch import nn, Tensor
@@ -58,4 +58,4 @@ class Transformer(nn.Module):
x = self.project_emb(x)
x = self.attn_layers(x, mask=mask, **kwargs)
out = self.logits(x) if not return_embeddings else x
- return x
+ return out