path: root/text_recognizer/networks/transformer/layers.py
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-11-17 22:42:29 +0100
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-11-17 22:42:29 +0100
commit     2cb2c5b38f0711267fecfe9c5e10940f4b4f79fc (patch)
tree       e6e3dfe027a365e2ad5a14c373cad5f2aa77b3ac /text_recognizer/networks/transformer/layers.py
parent     91db5e23f86ec0b829aebef6eef642bcf63da53b (diff)
Remove local attention
Diffstat (limited to 'text_recognizer/networks/transformer/layers.py')
-rw-r--r--  text_recognizer/networks/transformer/layers.py  16
1 file changed, 1 insertion(+), 15 deletions(-)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 8387fa4..67558ad 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -6,7 +6,6 @@ import attr
 from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.local_attention import LocalAttention
 from text_recognizer.networks.transformer.mlp import FeedForward
 from text_recognizer.networks.transformer.residual import Residual
 
@@ -24,18 +23,13 @@ class AttentionLayers(nn.Module):
     norm: Type[nn.Module] = attr.ib()
     ff: FeedForward = attr.ib()
     cross_attn: Optional[Attention] = attr.ib(default=None)
-    local_self_attn: Optional[LocalAttention] = attr.ib(default=None)
     pre_norm: bool = attr.ib(default=True)
-    local_depth: Optional[int] = attr.ib(default=None)
     has_pos_emb: bool = attr.ib(default=False)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        if self.local_self_attn is not None:
-            if self.local_depth is None:
-                ValueError("Local depth has to be specified")
         self.layer_types = self._get_layer_types() * self.depth
         self.layers = self._build_network()
 
@@ -45,14 +39,8 @@ class AttentionLayers(nn.Module):
             return "a", "c", "f"
         return "a", "f"
 
-    def _self_attn_block(self, i: int) -> Type[nn.Module]:
-        if self.local_depth is not None and i < self.local_depth:
-            return deepcopy(self.local_self_attn)
-        return deepcopy(self.self_attn)
-
     def _delete(self) -> None:
         del self.self_attn
-        del self.local_self_attn
         del self.ff
         del self.norm
         del self.cross_attn
@@ -60,11 +48,9 @@ class AttentionLayers(nn.Module):
     def _build_network(self) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
-        self_attn_depth = 0
         for layer_type in self.layer_types:
             if layer_type == "a":
-                layer = self._self_attn_block(self_attn_depth)
-                self_attn_depth += 1
+                layer = deepcopy(self.self_attn)
             elif layer_type == "c":
                 layer = deepcopy(self.cross_attn)
             elif layer_type == "f":
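
After this change, every "a" entry in layer_types resolves to a deep copy of the single self_attn template; the local_depth bookkeeping and the LocalAttention branch are gone. Below is a minimal, self-contained sketch of that construction pattern, not the repository's actual class: the name AttentionLayersSketch, the plain __init__, and the generic nn.Module type hints are illustrative stand-ins, and the real attrs-based AttentionLayers carries further fields (norm, pre_norm, has_pos_emb) that are omitted here.

# Minimal sketch of the post-commit construction pattern (illustrative stand-in,
# not the repository's attrs-based AttentionLayers).
from copy import deepcopy
from typing import Optional, Tuple

from torch import nn


class AttentionLayersSketch(nn.Module):
    """Stack deep copies of template modules according to layer_types."""

    def __init__(
        self,
        depth: int,
        self_attn: nn.Module,
        ff: nn.Module,
        cross_attn: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.self_attn = self_attn
        self.ff = ff
        self.cross_attn = cross_attn
        self.layer_types: Tuple[str, ...] = self._get_layer_types() * depth
        self.layers = self._build_network()

    def _get_layer_types(self) -> Tuple[str, ...]:
        """Self-attention, optional cross-attention, then feed-forward per block."""
        if self.cross_attn is not None:
            return ("a", "c", "f")
        return ("a", "f")

    def _build_network(self) -> nn.ModuleList:
        """Every "a" slot is a plain deep copy of the self-attention template."""
        layers = nn.ModuleList([])
        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = deepcopy(self.self_attn)
            elif layer_type == "c":
                layer = deepcopy(self.cross_attn)
            else:  # "f"
                layer = deepcopy(self.ff)
            layers.append(layer)
        return layers

For example, AttentionLayersSketch(depth=2, self_attn=nn.Identity(), ff=nn.Identity()) produces layer_types ("a", "f", "a", "f") and four independently copied modules, mirroring how the simplified _build_network in the patch treats every self-attention slot identically.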