Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 13 |
1 file changed, 5 insertions, 8 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 4063425..b2c703f 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -4,13 +4,12 @@
 from typing import Any, Dict, Optional, Type
 
 from click.types import Tuple
-import torch
 from torch import nn, Tensor
 
 from .attention import Attention
 from .mlp import FeedForward
 from .residual import Residual
-from .rotary_embedding import RotaryEmbedding
+from .positional_encodings.rotary_embedding import RotaryEmbedding
 
 
 class AttentionLayers(nn.Module):
@@ -24,7 +23,6 @@ class AttentionLayers(nn.Module):
         attn_fn: Type[nn.Module] = Attention,
         norm_fn: Type[nn.Module] = nn.LayerNorm,
         ff_fn: Type[nn.Module] = FeedForward,
-        residual_fn: Type[nn.Module] = Residual,
         rotary_emb: Optional[Type[nn.Module]] = None,
         rotary_emb_dim: Optional[int] = None,
         causal: bool = False,
@@ -33,10 +31,10 @@ class AttentionLayers(nn.Module):
     ) -> None:
         super().__init__()
         attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
-        norm_fn = partial(norm_fn, dim=dim)
+        norm_fn = partial(norm_fn, dim)
         ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
         self.layer_types = self._get_layer_types(cross_attend) * depth
-        self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn)
+        self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn)
         rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
         self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
         self.pre_norm = pre_norm
@@ -55,7 +53,6 @@ class AttentionLayers(nn.Module):
         attn_fn: partial,
         norm_fn: partial,
         ff_fn: partial,
-        residual_fn: Type[nn.Module],
     ) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
@@ -67,9 +64,9 @@ class AttentionLayers(nn.Module):
             elif layer_type == "f":
                 layer = ff_fn()
-            residual_fn = residual_fn()
+            residual_fn = Residual()
 
-            layers.append(nn.modulelist([norm_fn(), layer, residual_fn]))
+            layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
 
         return layers
 
     def forward(
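
For context, a minimal sketch of what the (norm_fn(), layer, Residual()) triples assembled in _build_network are typically used for: a pre-norm residual block, where the sub-layer (attention or feed-forward) runs on the normalized input and a skip connection is added back. This is illustrative only and not code from this commit; the stand-in modules and variable names below are assumptions.

    # Illustrative sketch (not this repo's code): a pre-norm residual block of the
    # kind the (norm, layer, residual) triples above are built for.
    import torch
    from torch import nn

    dim = 64
    norm = nn.LayerNorm(dim)        # what norm_fn() produces (nn.LayerNorm by default)
    layer = nn.Linear(dim, dim)     # stand-in for an Attention or FeedForward sub-layer
    x = torch.randn(1, 10, dim)

    # The Residual module in the diff is assumed to wrap exactly this pattern:
    # normalize first, apply the sub-layer, then add the skip connection.
    out = x + layer(norm(x))

Instantiating Residual() inside _build_network (rather than taking residual_fn as a constructor argument) keeps the constructor signature smaller, since the residual wrapper takes no configuration of its own.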