From 49ca6ade1a19f7f9c702171537fe4be0dfcda66d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 25 Aug 2023 23:19:14 +0200
Subject: Rename and add flash attention

---
 .../networks/transformer/decoder_block.py | 44 ----------------------
 1 file changed, 44 deletions(-)
 delete mode 100644 text_recognizer/networks/transformer/decoder_block.py

(limited to 'text_recognizer/networks/transformer/decoder_block.py')

diff --git a/text_recognizer/networks/transformer/decoder_block.py b/text_recognizer/networks/transformer/decoder_block.py
deleted file mode 100644
index b8eb5c4..0000000
--- a/text_recognizer/networks/transformer/decoder_block.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Transformer decoder module."""
-from copy import deepcopy
-from typing import Optional, Type
-
-from torch import Tensor, nn
-
-from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
-from text_recognizer.networks.transformer.ff import FeedForward
-
-
-class DecoderBlock(nn.Module):
-    """Residual decoder block."""
-
-    def __init__(
-        self,
-        self_attn: Attention,
-        norm: Type[nn.Module],
-        ff: FeedForward,
-        cross_attn: Optional[Attention] = None,
-    ) -> None:
-        super().__init__()
-        self.ln_attn = norm
-        self.attn = self_attn
-        self.ln_cross_attn = deepcopy(norm)
-        self.cross_attn = cross_attn
-        self.ln_ff = deepcopy(norm)
-        self.ff = ff
-
-    def forward(
-        self,
-        x: Tensor,
-        rotary_embedding: RotaryEmbedding,
-        context: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies decoder block on input signals."""
-        x = x + self.attn(self.ln_attn(x), mask=mask, rotary_embedding=rotary_embedding)
-        x = x + self.cross_attn(
-            x=self.ln_cross_attn(x),
-            context=context,
-        )
-        x = x + self.ff(self.ln_ff(x))
-        return x
--
cgit v1.2.3-70-g09d2
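
Note: the subject says the commit renames the decoder module and adds flash attention, but this view is limited to the deleted decoder_block.py, so the replacement file is not shown. Below is a minimal sketch of what a flash-attention decoder block could look like, assuming PyTorch 2.x, whose torch.nn.functional.scaled_dot_product_attention dispatches to a fused flash-attention kernel when one is available. The names FlashSelfAttention, FlashDecoderBlock, dim, and num_heads are illustrative, not taken from the repository.

"""Illustrative flash-attention decoder block; not the repository's actual replacement."""
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class FlashSelfAttention(nn.Module):
    """Causal multi-head self-attention via the fused SDPA kernel."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        b, n, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # Split heads: (B, N, D) -> (B, H, N, D // H).
        q, k, v = (t.view(b, n, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        # is_causal=True builds the autoregressive mask inside the kernel.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(out.transpose(1, 2).reshape(b, n, d))


class FlashDecoderBlock(nn.Module):
    """Pre-norm residual wiring, mirroring the deleted DecoderBlock."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        self.ln_attn = nn.LayerNorm(dim)
        self.attn = FlashSelfAttention(dim, num_heads)
        self.ln_ff = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.attn(self.ln_attn(x))
        return x + self.ff(self.ln_ff(x))


if __name__ == "__main__":
    block = FlashDecoderBlock(dim=256, num_heads=8)
    y = block(torch.randn(2, 16, 256))  # -> shape (2, 16, 256)

The deleted block also threaded a RotaryEmbedding and an optional cross-attention through forward; with the fused kernel, rotary rotations would be applied to q and k before the scaled_dot_product_attention call, and cross-attention would be a second SDPA call with k/v derived from the context. Both are omitted here for brevity. Incidentally, the deleted forward invoked self.cross_attn unconditionally even though it defaults to None, which would raise a TypeError in decoder-only use; a replacement would want to guard that call.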