From 8c7768e8d321efec558e12bff9b89b2de615d541 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 13 May 2021 23:02:20 +0200 Subject: Decoder module working --- text_recognizer/networks/transformer/attention.py | 1 + text_recognizer/networks/transformer/layers.py | 13 +++++-------- text_recognizer/networks/transformer/transformer.py | 4 ++-- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index eabeadf..a3b53f0 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -23,6 +23,7 @@ class Attention(nn.Module): dropout_rate: float = 0.0, causal: bool = False, ) -> None: + super().__init__() self.scale = dim ** -0.5 self.num_heads = num_heads self.causal = causal diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 4063425..b2c703f 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -4,13 +4,12 @@ from typing import Any, Dict, Optional, Type from click.types import Tuple -import torch from torch import nn, Tensor from .attention import Attention from .mlp import FeedForward from .residual import Residual -from .rotary_embedding import RotaryEmbedding +from .positional_encodings.rotary_embedding import RotaryEmbedding class AttentionLayers(nn.Module): @@ -24,7 +23,6 @@ class AttentionLayers(nn.Module): attn_fn: Type[nn.Module] = Attention, norm_fn: Type[nn.Module] = nn.LayerNorm, ff_fn: Type[nn.Module] = FeedForward, - residual_fn: Type[nn.Module] = Residual, rotary_emb: Optional[Type[nn.Module]] = None, rotary_emb_dim: Optional[int] = None, causal: bool = False, @@ -33,10 +31,10 @@ class AttentionLayers(nn.Module): ) -> None: super().__init__() attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) - norm_fn = partial(norm_fn, dim=dim) + norm_fn = partial(norm_fn, dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) self.layer_types = self._get_layer_types(cross_attend) * depth - self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn) + self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn) rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None self.pre_norm = pre_norm @@ -55,7 +53,6 @@ class AttentionLayers(nn.Module): attn_fn: partial, norm_fn: partial, ff_fn: partial, - residual_fn: Type[nn.Module], ) -> nn.ModuleList: """Configures transformer network.""" layers = nn.ModuleList([]) @@ -67,9 +64,9 @@ class AttentionLayers(nn.Module): elif layer_type == "f": layer = ff_fn() - residual_fn = residual_fn() + residual_fn = Residual() - layers.append(nn.modulelist([norm_fn(), layer, residual_fn])) + layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) return layers def forward( diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index 36f86ac..60ab1ce 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -1,5 +1,5 @@ """Transformer wrapper.""" -from typing import Optional, Type +from typing import Any, Optional, Type from torch import nn, Tensor @@ -58,4 +58,4 @@ class Transformer(nn.Module): x = self.project_emb(x) x = self.attn_layers(x, mask=mask, **kwargs) out = self.logits(x) if not return_embeddings else x - return x + return out -- cgit v1.2.3-70-g09d2