From 684da19a2ca83ee61011c37e36fa71b9eeb5ca6a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Sep 2023 22:12:25 +0200 Subject: Update encoder/decoder attention and forward pass --- text_recognizer/network/transformer/decoder.py | 40 ++++++++++++++++---------- text_recognizer/network/transformer/encoder.py | 26 +++++++---------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py index 24a8ac4..4ebdd2c 100644 --- a/text_recognizer/network/transformer/decoder.py +++ b/text_recognizer/network/transformer/decoder.py @@ -3,14 +3,14 @@ from typing import Optional from torch import Tensor, nn from .attention import Attention -from .ff import FeedForward +from .embedding.rotary import RotaryEmbedding class Decoder(nn.Module): def __init__( self, dim: int, - inner_dim: int, + ff_mult: int, heads: int, dim_head: int, depth: int, @@ -23,19 +23,25 @@ class Decoder(nn.Module): nn.ModuleList( [ Attention( - dim, - heads, - True, - dim_head, - dropout_rate, + dim=dim, + heads=heads, + causal=True, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, + rotary_emb=RotaryEmbedding(dim_head), ), - FeedForward(dim, inner_dim, dropout_rate), Attention( - dim, - heads, - False, - dim_head, - dropout_rate, + dim=dim, + heads=heads, + causal=False, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, ), ] ) @@ -43,6 +49,11 @@ class Decoder(nn.Module): ] ) + def self_attn(self, x: Tensor, mask: Tensor) -> Tensor: + for self_attn, _ in self.layers: + x = x + self_attn(x, mask=mask) + return self.norm(x) + def forward( self, x: Tensor, @@ -50,8 +61,7 @@ class Decoder(nn.Module): mask: Optional[Tensor] = None, ) -> Tensor: """Applies decoder block on input signals.""" - for self_attn, ff, cross_attn in self.layers: + for self_attn, cross_attn in self.layers: x = x + self_attn(x, mask=mask) - x = x + ff(x) x = x + cross_attn(x, context=context) return self.norm(x) diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index 328a40c..1728c61 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -2,16 +2,15 @@ from torch import Tensor, nn from .attention import Attention -from .ff import FeedForward class Encoder(nn.Module): def __init__( self, dim: int, - inner_dim: int, heads: int, dim_head: int, + ff_mult: int, depth: int, dropout_rate: float = 0.0, ) -> None: @@ -19,17 +18,15 @@ class Encoder(nn.Module): self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList( [ - nn.ModuleList( - [ - Attention( - dim, - heads, - False, - dim_head, - dropout_rate, - ), - FeedForward(dim, inner_dim, dropout_rate), - ] + Attention( + dim=dim, + heads=heads, + causal=False, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, ) for _ in range(depth) ] @@ -40,7 +37,6 @@ class Encoder(nn.Module): x: Tensor, ) -> Tensor: """Applies decoder block on input signals.""" - for self_attn, ff in self.layers: + for self_attn in self.layers: x = x + self_attn(x) - x = x + ff(x) return self.norm(x) -- cgit v1.2.3-70-g09d2