From 8c7768e8d321efec558e12bff9b89b2de615d541 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 13 May 2021 23:02:20 +0200 Subject: Decoder module working --- text_recognizer/networks/transformer/layers.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'text_recognizer/networks/transformer/layers.py') diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 4063425..b2c703f 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -4,13 +4,12 @@ from typing import Any, Dict, Optional, Type from click.types import Tuple -import torch from torch import nn, Tensor from .attention import Attention from .mlp import FeedForward from .residual import Residual -from .rotary_embedding import RotaryEmbedding +from .positional_encodings.rotary_embedding import RotaryEmbedding class AttentionLayers(nn.Module): @@ -24,7 +23,6 @@ class AttentionLayers(nn.Module): attn_fn: Type[nn.Module] = Attention, norm_fn: Type[nn.Module] = nn.LayerNorm, ff_fn: Type[nn.Module] = FeedForward, - residual_fn: Type[nn.Module] = Residual, rotary_emb: Optional[Type[nn.Module]] = None, rotary_emb_dim: Optional[int] = None, causal: bool = False, @@ -33,10 +31,10 @@ class AttentionLayers(nn.Module): ) -> None: super().__init__() attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) - norm_fn = partial(norm_fn, dim=dim) + norm_fn = partial(norm_fn, dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) self.layer_types = self._get_layer_types(cross_attend) * depth - self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn) + self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn) rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None self.pre_norm = pre_norm @@ -55,7 +53,6 @@ class AttentionLayers(nn.Module): attn_fn: partial, norm_fn: partial, ff_fn: partial, - residual_fn: Type[nn.Module], ) -> nn.ModuleList: """Configures transformer network.""" layers = nn.ModuleList([]) @@ -67,9 +64,9 @@ class AttentionLayers(nn.Module): elif layer_type == "f": layer = ff_fn() - residual_fn = residual_fn() + residual_fn = Residual() - layers.append(nn.modulelist([norm_fn(), layer, residual_fn])) + layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) return layers def forward( -- cgit v1.2.3-70-g09d2