From 82a8efc3ba5dd2048b3b46e59c2da0face44fed1 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 28 Nov 2021 21:36:59 +0100 Subject: Refactor attention layer module --- text_recognizer/networks/transformer/layers.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 67558ad..fc32f20 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -15,7 +15,6 @@ class AttentionLayers(nn.Module): """Standard transfomer layer.""" def __attrs_pre_init__(self) -> None: - """Pre init constructor.""" super().__init__() depth: int = attr.ib() @@ -29,9 +28,8 @@ class AttentionLayers(nn.Module): layers: nn.ModuleList = attr.ib(init=False) def __attrs_post_init__(self) -> None: - """Post init configuration.""" self.layer_types = self._get_layer_types() * self.depth - self.layers = self._build_network() + self.layers = self._build() def _get_layer_types(self) -> Tuple: """Get layer specification.""" @@ -45,7 +43,7 @@ class AttentionLayers(nn.Module): del self.norm del self.cross_attn - def _build_network(self) -> nn.ModuleList: + def _build(self) -> nn.ModuleList: """Configures transformer network.""" layers = nn.ModuleList([]) for layer_type in self.layer_types: @@ -97,4 +95,17 @@ class Decoder(AttentionLayers): """Decoder module.""" def __init__(self, **kwargs: Any) -> None: + if "cross_attn" not in kwargs: + ValueError("Decoder requires cross attention.") + + super().__init__(**kwargs) + + +class Encoder(AttentionLayers): + """Encoder module.""" + + def __init__(self, **kwargs: Any) -> None: + if "cross_attn" in kwargs: + ValueError("Encoder requires cross attention.") + super().__init__(**kwargs) -- cgit v1.2.3-70-g09d2