summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/encoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:12:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-11 22:12:25 +0200
commit684da19a2ca83ee61011c37e36fa71b9eeb5ca6a (patch)
tree5cce2ddda428648c137b3083673c9650454ac973 /text_recognizer/network/transformer/encoder.py
parent925cf2f4e92b222af7bc4dd95fe47dba136c10bd (diff)
Update encoder/decoder attention and forward pass
Diffstat (limited to 'text_recognizer/network/transformer/encoder.py')
-rw-r--r--text_recognizer/network/transformer/encoder.py26
1 files changed, 11 insertions, 15 deletions
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index 328a40c..1728c61 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -2,16 +2,15 @@
from torch import Tensor, nn
from .attention import Attention
-from .ff import FeedForward
class Encoder(nn.Module):
def __init__(
self,
dim: int,
- inner_dim: int,
heads: int,
dim_head: int,
+ ff_mult: int,
depth: int,
dropout_rate: float = 0.0,
) -> None:
@@ -19,17 +18,15 @@ class Encoder(nn.Module):
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList(
[
- nn.ModuleList(
- [
- Attention(
- dim,
- heads,
- False,
- dim_head,
- dropout_rate,
- ),
- FeedForward(dim, inner_dim, dropout_rate),
- ]
+ Attention(
+ dim=dim,
+ heads=heads,
+ causal=False,
+ dim_head=dim_head,
+ ff_mult=ff_mult,
+ dropout_rate=dropout_rate,
+ use_flash=True,
+ norm_context=False,
)
for _ in range(depth)
]
@@ -40,7 +37,6 @@ class Encoder(nn.Module):
x: Tensor,
) -> Tensor:
"""Applies decoder block on input signals."""
- for self_attn, ff in self.layers:
+ for self_attn in self.layers:
x = x + self_attn(x)
- x = x + ff(x)
return self.norm(x)