summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:12:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:12:25 +0200
commit684da19a2ca83ee61011c37e36fa71b9eeb5ca6a (patch)
tree5cce2ddda428648c137b3083673c9650454ac973
parent925cf2f4e92b222af7bc4dd95fe47dba136c10bd (diff)
Update encoder/decoder attention and forward pass
-rw-r--r--text_recognizer/network/transformer/decoder.py40
-rw-r--r--text_recognizer/network/transformer/encoder.py26
2 files changed, 36 insertions, 30 deletions
diff --git a/text_recognizer/network/transformer/decoder.py b/text_recognizer/network/transformer/decoder.py
index 24a8ac4..4ebdd2c 100644
--- a/text_recognizer/network/transformer/decoder.py
+++ b/text_recognizer/network/transformer/decoder.py
@@ -3,14 +3,14 @@ from typing import Optional
from torch import Tensor, nn
from .attention import Attention
-from .ff import FeedForward
+from .embedding.rotary import RotaryEmbedding
class Decoder(nn.Module):
def __init__(
self,
dim: int,
- inner_dim: int,
+ ff_mult: int,
heads: int,
dim_head: int,
depth: int,
@@ -23,19 +23,25 @@ class Decoder(nn.Module):
nn.ModuleList(
[
Attention(
- dim,
- heads,
- True,
- dim_head,
- dropout_rate,
+ dim=dim,
+ heads=heads,
+ causal=True,
+ dim_head=dim_head,
+ ff_mult=ff_mult,
+ dropout_rate=dropout_rate,
+ use_flash=True,
+ norm_context=False,
+ rotary_emb=RotaryEmbedding(dim_head),
),
- FeedForward(dim, inner_dim, dropout_rate),
Attention(
- dim,
- heads,
- False,
- dim_head,
- dropout_rate,
+ dim=dim,
+ heads=heads,
+ causal=False,
+ dim_head=dim_head,
+ ff_mult=ff_mult,
+ dropout_rate=dropout_rate,
+ use_flash=True,
+ norm_context=False,
),
]
)
@@ -43,6 +49,11 @@ class Decoder(nn.Module):
]
)
+ def self_attn(self, x: Tensor, mask: Tensor) -> Tensor:
+ for self_attn, _ in self.layers:
+ x = x + self_attn(x, mask=mask)
+ return self.norm(x)
+
def forward(
self,
x: Tensor,
@@ -50,8 +61,7 @@ class Decoder(nn.Module):
mask: Optional[Tensor] = None,
) -> Tensor:
"""Applies decoder block on input signals."""
- for self_attn, ff, cross_attn in self.layers:
+ for self_attn, cross_attn in self.layers:
x = x + self_attn(x, mask=mask)
- x = x + ff(x)
x = x + cross_attn(x, context=context)
return self.norm(x)
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index 328a40c..1728c61 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -2,16 +2,15 @@
from torch import Tensor, nn
from .attention import Attention
-from .ff import FeedForward
class Encoder(nn.Module):
def __init__(
self,
dim: int,
- inner_dim: int,
heads: int,
dim_head: int,
+ ff_mult: int,
depth: int,
dropout_rate: float = 0.0,
) -> None:
@@ -19,17 +18,15 @@ class Encoder(nn.Module):
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList(
[
- nn.ModuleList(
- [
- Attention(
- dim,
- heads,
- False,
- dim_head,
- dropout_rate,
- ),
- FeedForward(dim, inner_dim, dropout_rate),
- ]
+ Attention(
+ dim=dim,
+ heads=heads,
+ causal=False,
+ dim_head=dim_head,
+ ff_mult=ff_mult,
+ dropout_rate=dropout_rate,
+ use_flash=True,
+ norm_context=False,
)
for _ in range(depth)
]
@@ -40,7 +37,6 @@ class Encoder(nn.Module):
x: Tensor,
) -> Tensor:
"""Applies decoder block on input signals."""
- for self_attn, ff in self.layers:
+ for self_attn in self.layers:
x = x + self_attn(x)
- x = x + ff(x)
return self.norm(x)