diff options
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index b2c703f..a44a525 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,8 +1,6 @@ """Generates the attention layer architecture.""" from functools import partial -from typing import Any, Dict, Optional, Type - -from click.types import Tuple +from typing import Any, Dict, Optional, Tuple, Type from torch import nn, Tensor @@ -30,6 +28,7 @@ class AttentionLayers(nn.Module): pre_norm: bool = True, ) -> None: super().__init__() + self.dim = dim attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) norm_fn = partial(norm_fn, dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) |