diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 22:46:09 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 22:46:09 +0200 |
commit | c9c60678673e19ad3367339eb8e7a093e5a98474 (patch) | |
tree | b787a7fbb535c2ee44f935720d75034cc24ffd30 /text_recognizer/networks/transformer/layers.py | |
parent | a2a3133ed5da283888efbdb9924d0e3733c274c8 (diff) |
Reformatting of positional encodings and ViT working
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index a2fdb1a..4063425 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,6 +1,6 @@ """Generates the attention layer architecture.""" from functools import partial -from typing import Dict, Optional, Type +from typing import Any, Dict, Optional, Type from click.types import Tuple @@ -36,12 +36,11 @@ class AttentionLayers(nn.Module): norm_fn = partial(norm_fn, dim=dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) self.layer_types = self._get_layer_types(cross_attend) * depth - self.layers = self._build_network( - causal, attn_fn, norm_fn, ff_fn, residual_fn - ) + self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn) rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None self.pre_norm = pre_norm + self.has_pos_emb = True if self.rotary_emb is not None else False @staticmethod def _get_layer_types(cross_attend: bool) -> Tuple: @@ -70,7 +69,7 @@ class AttentionLayers(nn.Module): residual_fn = residual_fn() - layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + layers.append(nn.modulelist([norm_fn(), layer, residual_fn])) return layers def forward( @@ -82,10 +81,12 @@ class AttentionLayers(nn.Module): ) -> Tensor: rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None - for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + for i, (layer_type, (norm, block, residual_fn)) in enumerate( + zip(self.layer_types, self.layers) + ): is_last = i == len(self.layers) - 1 - - residual = x + + residual = x if self.pre_norm: x = norm(x) @@ -103,3 +104,15 @@ class AttentionLayers(nn.Module): x = norm(x) return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs: Any) -> None: + assert "causal" not in kwargs, "Cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs: Any) -> None: + assert "causal" not in kwargs, "Cannot set causality on decoder" + super().__init__(causal=True, **kwargs) |