summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-28 21:36:59 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-28 21:36:59 +0100
commit82a8efc3ba5dd2048b3b46e59c2da0face44fed1 (patch)
treefdd24a28d7d3f071d0ca61709d6ceec68227688b
parent271e901a073c8e81335ddb929e57dbc144d54b05 (diff)
Refactor attention layer module
-rw-r--r--text_recognizer/networks/transformer/layers.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 67558ad..fc32f20 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -15,7 +15,6 @@ class AttentionLayers(nn.Module):
"""Standard transfomer layer."""
def __attrs_pre_init__(self) -> None:
- """Pre init constructor."""
super().__init__()
depth: int = attr.ib()
@@ -29,9 +28,8 @@ class AttentionLayers(nn.Module):
layers: nn.ModuleList = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- """Post init configuration."""
self.layer_types = self._get_layer_types() * self.depth
- self.layers = self._build_network()
+ self.layers = self._build()
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
@@ -45,7 +43,7 @@ class AttentionLayers(nn.Module):
del self.norm
del self.cross_attn
- def _build_network(self) -> nn.ModuleList:
+ def _build(self) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
for layer_type in self.layer_types:
@@ -97,4 +95,17 @@ class Decoder(AttentionLayers):
"""Decoder module."""
def __init__(self, **kwargs: Any) -> None:
+ if "cross_attn" not in kwargs:
+ ValueError("Decoder requires cross attention.")
+
+ super().__init__(**kwargs)
+
+
+class Encoder(AttentionLayers):
+ """Encoder module."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ if "cross_attn" in kwargs:
+ ValueError("Encoder requires cross attention.")
+
super().__init__(**kwargs)