diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-28 21:36:59 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-28 21:36:59 +0100 |
commit | 82a8efc3ba5dd2048b3b46e59c2da0face44fed1 (patch) | |
tree | fdd24a28d7d3f071d0ca61709d6ceec68227688b /text_recognizer | |
parent | 271e901a073c8e81335ddb929e57dbc144d54b05 (diff) |
Refactor attention layer module
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 67558ad..fc32f20 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -15,7 +15,6 @@ class AttentionLayers(nn.Module):
     """Standard transfomer layer."""
 
     def __attrs_pre_init__(self) -> None:
-        """Pre init constructor."""
         super().__init__()
 
     depth: int = attr.ib()
@@ -29,9 +28,8 @@ class AttentionLayers(nn.Module):
     layers: nn.ModuleList = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
         self.layer_types = self._get_layer_types() * self.depth
-        self.layers = self._build_network()
+        self.layers = self._build()
 
     def _get_layer_types(self) -> Tuple:
         """Get layer specification."""
@@ -45,7 +43,7 @@ class AttentionLayers(nn.Module):
         del self.norm
         del self.cross_attn
 
-    def _build_network(self) -> nn.ModuleList:
+    def _build(self) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
         for layer_type in self.layer_types:
@@ -97,4 +95,17 @@ class Decoder(AttentionLayers):
     """Decoder module."""
 
     def __init__(self, **kwargs: Any) -> None:
+        if "cross_attn" not in kwargs:
+            ValueError("Decoder requires cross attention.")
+
+        super().__init__(**kwargs)
+
+
+class Encoder(AttentionLayers):
+    """Encoder module."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        if "cross_attn" in kwargs:
+            ValueError("Encoder requires cross attention.")
+
         super().__init__(**kwargs)