From 49ca6ade1a19f7f9c702171537fe4be0dfcda66d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:19:14 +0200 Subject: Rename and add flash atten --- text_recognizer/networks/transformer/decoder.py | 41 ------------------------- 1 file changed, 41 deletions(-) delete mode 100644 text_recognizer/networks/transformer/decoder.py (limited to 'text_recognizer/networks/transformer/decoder.py') diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py deleted file mode 100644 index 826bc13..0000000 --- a/text_recognizer/networks/transformer/decoder.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Transformer decoder module.""" -from copy import deepcopy -from typing import Optional - -from torch import Tensor, nn - -from text_recognizer.networks.transformer.decoder_block import DecoderBlock -from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding - - -class Decoder(nn.Module): - """Decoder Network.""" - - def __init__( - self, - depth: int, - dim: int, - block: DecoderBlock, - rotary_embedding: RotaryEmbedding, - ) -> None: - super().__init__() - self.depth = depth - self.rotary_embedding = rotary_embedding - self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)]) - self.ln = nn.LayerNorm(dim) - - def forward( - self, - x: Tensor, - context: Optional[Tensor] = None, - mask: Optional[Tensor] = None, - ) -> Tensor: - """Applies attention blocks.""" - for block in self.blocks: - x = block( - x=x, - context=context, - mask=mask, - rotary_embedding=self.rotary_embedding, - ) - return self.ln(x) -- cgit v1.2.3-70-g09d2