summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/network/transformer/decoder.py')
-rw-r--r--text_recognizer/network/transformer/decoder.py40
1 files changed, 25 insertions, 15 deletions
diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py
index 24a8ac4..4ebdd2c 100644
--- a/text_recognizer/network/transformer/decoder.py
+++ b/text_recognizer/network/transformer/decoder.py
@@ -3,14 +3,14 @@ from typing import Optional
from torch import Tensor, nn
from .attention import Attention
-from .ff import FeedForward
+from .embedding.rotary import RotaryEmbedding
class Decoder(nn.Module):
def __init__(
self,
dim: int,
- inner_dim: int,
+ ff_mult: int,
heads: int,
dim_head: int,
depth: int,
@@ -23,19 +23,25 @@ class Decoder(nn.Module):
nn.ModuleList(
[
Attention(
- dim,
- heads,
- True,
- dim_head,
- dropout_rate,
+ dim=dim,
+ heads=heads,
+ causal=True,
+ dim_head=dim_head,
+ ff_mult=ff_mult,
+ dropout_rate=dropout_rate,
+ use_flash=True,
+ norm_context=False,
+ rotary_emb=RotaryEmbedding(dim_head),
),
- FeedForward(dim, inner_dim, dropout_rate),
Attention(
- dim,
- heads,
- False,
- dim_head,
- dropout_rate,
+ dim=dim,
+ heads=heads,
+ causal=False,
+ dim_head=dim_head,
+ ff_mult=ff_mult,
+ dropout_rate=dropout_rate,
+ use_flash=True,
+ norm_context=False,
),
]
)
@@ -43,6 +49,11 @@ class Decoder(nn.Module):
]
)
+ def self_attn(self, x: Tensor, mask: Tensor) -> Tensor:
+ for self_attn, _ in self.layers:
+ x = x + self_attn(x, mask=mask)
+ return self.norm(x)
+
def forward(
self,
x: Tensor,
@@ -50,8 +61,7 @@ class Decoder(nn.Module):
mask: Optional[Tensor] = None,
) -> Tensor:
"""Applies decoder block on input signals."""
- for self_attn, ff, cross_attn in self.layers:
+ for self_attn, cross_attn in self.layers:
x = x + self_attn(x, mask=mask)
- x = x + ff(x)
x = x + cross_attn(x, context=context)
return self.norm(x)