diff options
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 66c9c50..ce443e5 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding from text_recognizer.networks.util import load_partial_fn -@attr.s +@attr.s(eq=False) class AttentionLayers(nn.Module): """Standard transfomer layer.""" @@ -101,11 +101,11 @@ class AttentionLayers(nn.Module): return x -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Encoder(AttentionLayers): causal: bool = attr.ib(default=False, init=False) -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Decoder(AttentionLayers): causal: bool = attr.ib(default=True, init=False) |