summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/transformer/layers.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-05-13 23:02:20 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-05-13 23:02:20 +0200
commit: 8c7768e8d321efec558e12bff9b89b2de615d541 (patch)
tree: 67f5928c5584e8826c01834d06d34cd7e60546ba /text_recognizer/networks/transformer/layers.py
parent: c9c60678673e19ad3367339eb8e7a093e5a98474 (diff)
Decoder module working
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- text_recognizer/networks/transformer/layers.py | 13
1 files changed, 5 insertions, 8 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 4063425..b2c703f 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -4,13 +4,12 @@ from typing import Any, Dict, Optional, Type
from click.types import Tuple
-import torch
from torch import nn, Tensor
from .attention import Attention
from .mlp import FeedForward
from .residual import Residual
-from .rotary_embedding import RotaryEmbedding
+from .positional_encodings.rotary_embedding import RotaryEmbedding
class AttentionLayers(nn.Module):
@@ -24,7 +23,6 @@ class AttentionLayers(nn.Module):
attn_fn: Type[nn.Module] = Attention,
norm_fn: Type[nn.Module] = nn.LayerNorm,
ff_fn: Type[nn.Module] = FeedForward,
- residual_fn: Type[nn.Module] = Residual,
rotary_emb: Optional[Type[nn.Module]] = None,
rotary_emb_dim: Optional[int] = None,
causal: bool = False,
@@ -33,10 +31,10 @@ class AttentionLayers(nn.Module):
) -> None:
super().__init__()
attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
- norm_fn = partial(norm_fn, dim=dim)
+ norm_fn = partial(norm_fn, dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
self.layer_types = self._get_layer_types(cross_attend) * depth
- self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn)
+ self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
self.pre_norm = pre_norm
@@ -55,7 +53,6 @@ class AttentionLayers(nn.Module):
attn_fn: partial,
norm_fn: partial,
ff_fn: partial,
- residual_fn: Type[nn.Module],
) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
@@ -67,9 +64,9 @@ class AttentionLayers(nn.Module):
elif layer_type == "f":
layer = ff_fn()
- residual_fn = residual_fn()
+ residual_fn = Residual()
- layers.append(nn.modulelist([norm_fn(), layer, residual_fn]))
+ layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
return layers
def forward(