author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-01-26 23:18:06 +0100
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2022-01-26 23:18:06 +0100
commit    68640a70a769568fd571e50322d6e3fa40c78271 (patch)
tree      bf888f83b72ff3ef09edecb35ad41c7839f3733c /text_recognizer/networks
parent    53a4af4ca22ced165fde14cd0de46e29aab7d80d (diff)
fix: refactor AttentionLayers
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/transformer/layers.py  54
1 file changed, 25 insertions(+), 29 deletions(-)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index fc32f20..4263f52 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -2,7 +2,6 @@
from copy import deepcopy
from typing import Any, Optional, Tuple, Type
-import attr
from torch import nn, Tensor
from text_recognizer.networks.transformer.attention import Attention
@@ -10,26 +9,24 @@ from text_recognizer.networks.transformer.mlp import FeedForward
from text_recognizer.networks.transformer.residual import Residual
-@attr.s(eq=False)
class AttentionLayers(nn.Module):
"""Standard transfomer layer."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ depth: int,
+ self_attn: Attention,
+ norm: Type[nn.Module],
+ ff: FeedForward,
+ cross_attn: Optional[Attention] = None,
+ pre_norm: bool = True,
+ has_pos_emb: bool = True,
+ ) -> None:
super().__init__()
-
- depth: int = attr.ib()
- self_attn: Attention = attr.ib()
- norm: Type[nn.Module] = attr.ib()
- ff: FeedForward = attr.ib()
- cross_attn: Optional[Attention] = attr.ib(default=None)
- pre_norm: bool = attr.ib(default=True)
- has_pos_emb: bool = attr.ib(default=False)
- layer_types: Tuple[str, ...] = attr.ib(init=False)
- layers: nn.ModuleList = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- self.layer_types = self._get_layer_types() * self.depth
- self.layers = self._build()
+ self.pre_norm = pre_norm
+ self.has_pos_emb = has_pos_emb
+ self.layer_types = self._get_layer_types() * depth
+ self.layers = self._build(self_attn, norm, ff, cross_attn)
def _get_layer_types(self) -> Tuple:
"""Get layer specification."""
@@ -37,24 +34,23 @@ class AttentionLayers(nn.Module):
return "a", "c", "f"
return "a", "f"
- def _delete(self) -> None:
- del self.self_attn
- del self.ff
- del self.norm
- del self.cross_attn
-
- def _build(self) -> nn.ModuleList:
+ def _build(
+ self,
+ self_attn: Attention,
+ norm: Type[nn.Module],
+ ff: FeedForward,
+ cross_attn: Optional[Attention],
+ ) -> nn.ModuleList:
"""Configures transformer network."""
layers = nn.ModuleList([])
for layer_type in self.layer_types:
if layer_type == "a":
- layer = deepcopy(self.self_attn)
+ layer = deepcopy(self_attn)
elif layer_type == "c":
- layer = deepcopy(self.cross_attn)
+ layer = deepcopy(cross_attn)
elif layer_type == "f":
- layer = deepcopy(self.ff)
- layers.append(nn.ModuleList([deepcopy(self.norm), layer, Residual()]))
- self._delete()
+ layer = deepcopy(ff)
+ layers.append(nn.ModuleList([deepcopy(norm), layer, Residual()]))
return layers
def forward(
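
The commit replaces the attrs-based construction (`@attr.s(eq=False)` with `__attrs_pre_init__`/`__attrs_post_init__` and `attr.ib()` fields) by a conventional `__init__` that receives the prototype modules, deep-copies them into an `nn.ModuleList`, and keeps only the flags it actually needs on `self`, so the `_delete` workaround goes away. Below is a minimal, self-contained sketch of the same pattern. It uses generic `nn.Linear`/`nn.LayerNorm` stand-ins and a hypothetical class name `TinyLayers`, not the repository's `Attention`/`FeedForward`/`Residual` classes, whose constructors are not shown in this diff.

from copy import deepcopy

from torch import nn, Tensor


class TinyLayers(nn.Module):
    """Simplified analogue of the refactored AttentionLayers.

    Prototype sub-modules are passed to __init__ and deep-copied once per
    layer, so only configuration flags and the built nn.ModuleList remain
    as attributes on self.
    """

    def __init__(
        self,
        depth: int,
        block: nn.Module,
        norm: nn.Module,
        pre_norm: bool = True,
    ) -> None:
        super().__init__()
        self.pre_norm = pre_norm
        self.layers = nn.ModuleList(
            nn.ModuleList([deepcopy(norm), deepcopy(block)]) for _ in range(depth)
        )

    def forward(self, x: Tensor) -> Tensor:
        for norm, block in self.layers:
            residual = x
            if self.pre_norm:
                x = block(norm(x)) + residual
            else:
                x = norm(block(x) + residual)
        return x


# Hypothetical usage; the dimensions are arbitrary.
layers = TinyLayers(depth=2, block=nn.Linear(64, 64), norm=nn.LayerNorm(64))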