From 684da19a2ca83ee61011c37e36fa71b9eeb5ca6a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Sep 2023 22:12:25 +0200 Subject: Update encoder/decoder attention and forward pass --- text_recognizer/network/transformer/decoder.py | 40 ++++++++++++++++---------- 1 file changed, 25 insertions(+), 15 deletions(-) (limited to 'text_recognizer/network/transformer/decoder.py') diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py index 24a8ac4..4ebdd2c 100644 --- a/text_recognizer/network/transformer/decoder.py +++ b/text_recognizer/network/transformer/decoder.py @@ -3,14 +3,14 @@ from typing import Optional from torch import Tensor, nn from .attention import Attention -from .ff import FeedForward +from .embedding.rotary import RotaryEmbedding class Decoder(nn.Module): def __init__( self, dim: int, - inner_dim: int, + ff_mult: int, heads: int, dim_head: int, depth: int, @@ -23,19 +23,25 @@ class Decoder(nn.Module): nn.ModuleList( [ Attention( - dim, - heads, - True, - dim_head, - dropout_rate, + dim=dim, + heads=heads, + causal=True, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, + rotary_emb=RotaryEmbedding(dim_head), ), - FeedForward(dim, inner_dim, dropout_rate), Attention( - dim, - heads, - False, - dim_head, - dropout_rate, + dim=dim, + heads=heads, + causal=False, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, ), ] ) @@ -43,6 +49,11 @@ class Decoder(nn.Module): ] ) + def self_attn(self, x: Tensor, mask: Tensor) -> Tensor: + for self_attn, _ in self.layers: + x = x + self_attn(x, mask=mask) + return self.norm(x) + def forward( self, x: Tensor, @@ -50,8 +61,7 @@ class Decoder(nn.Module): mask: Optional[Tensor] = None, ) -> Tensor: """Applies decoder block on input signals.""" - for self_attn, ff, cross_attn in self.layers: + for self_attn, cross_attn in self.layers: x = x + self_attn(x, mask=mask) - x = x + ff(x) x = x + cross_attn(x, context=context) return self.norm(x) -- cgit v1.2.3-70-g09d2