diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:14 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:14 +0200 |
commit | 49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch) | |
tree | 20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/networks/transformer/decoder.py | |
parent | 0421daf6bd97596703f426ba61c401599b538eeb (diff) |
Rename and add flash atten
Diffstat (limited to 'text_recognizer/networks/transformer/decoder.py')
-rw-r--r-- | text_recognizer/networks/transformer/decoder.py | 41 |
1 files changed, 0 insertions, 41 deletions
```diff
diff --git a/text_recognizer/networks/transformer/decoder.py b/text_recognizer/networks/transformer/decoder.py
deleted file mode 100644
index 826bc13..0000000
--- a/text_recognizer/networks/transformer/decoder.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Transformer decoder module."""
-from copy import deepcopy
-from typing import Optional
-
-from torch import Tensor, nn
-
-from text_recognizer.networks.transformer.decoder_block import DecoderBlock
-from text_recognizer.networks.transformer.embeddings.rotary import RotaryEmbedding
-
-
-class Decoder(nn.Module):
-    """Decoder Network."""
-
-    def __init__(
-        self,
-        depth: int,
-        dim: int,
-        block: DecoderBlock,
-        rotary_embedding: RotaryEmbedding,
-    ) -> None:
-        super().__init__()
-        self.depth = depth
-        self.rotary_embedding = rotary_embedding
-        self.blocks = nn.ModuleList([deepcopy(block) for _ in range(self.depth)])
-        self.ln = nn.LayerNorm(dim)
-
-    def forward(
-        self,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Applies attention blocks."""
-        for block in self.blocks:
-            x = block(
-                x=x,
-                context=context,
-                mask=mask,
-                rotary_embedding=self.rotary_embedding,
-            )
-        return self.ln(x)
```