diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-11 22:12:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-11 22:12:25 +0200 |
commit | 684da19a2ca83ee61011c37e36fa71b9eeb5ca6a (patch) | |
tree | 5cce2ddda428648c137b3083673c9650454ac973 /text_recognizer/network/transformer/encoder.py | |
parent | 925cf2f4e92b222af7bc4dd95fe47dba136c10bd (diff) |
Update encoder/decoder attention and forward pass
Diffstat (limited to 'text_recognizer/network/transformer/encoder.py')
-rw-r--r-- | text_recognizer/network/transformer/encoder.py | 26 |
1 files changed, 11 insertions, 15 deletions
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index 328a40c..1728c61 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -2,16 +2,15 @@ from torch import Tensor, nn from .attention import Attention -from .ff import FeedForward class Encoder(nn.Module): def __init__( self, dim: int, - inner_dim: int, heads: int, dim_head: int, + ff_mult: int, depth: int, dropout_rate: float = 0.0, ) -> None: @@ -19,17 +18,15 @@ class Encoder(nn.Module): self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList( [ - nn.ModuleList( - [ - Attention( - dim, - heads, - False, - dim_head, - dropout_rate, - ), - FeedForward(dim, inner_dim, dropout_rate), - ] + Attention( + dim=dim, + heads=heads, + causal=False, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, ) for _ in range(depth) ] @@ -40,7 +37,6 @@ class Encoder(nn.Module): x: Tensor, ) -> Tensor: """Applies decoder block on input signals.""" - for self_attn, ff in self.layers: + for self_attn in self.layers: x = x + self_attn(x) - x = x + ff(x) return self.norm(x) |