summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/attention.py2
-rw-r--r--text_recognizer/networks/transformer/layers.py16
-rw-r--r--text_recognizer/networks/transformer/norm.py8
3 files changed, 11 insertions, 15 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 2770dc1..9202cce 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -24,9 +24,9 @@ class Attention(nn.Module):
dim: int = attr.ib()
num_heads: int = attr.ib()
+ causal: bool = attr.ib(default=False)
dim_head: int = attr.ib(default=64)
dropout_rate: float = attr.ib(default=0.0)
- casual: bool = attr.ib(default=False)
scale: float = attr.ib(init=False)
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 9b2f236..66c9c50 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -30,8 +30,7 @@ class AttentionLayers(nn.Module):
causal: bool = attr.ib(default=False)
cross_attend: bool = attr.ib(default=False)
pre_norm: bool = attr.ib(default=True)
- rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
- has_pos_emb: bool = attr.ib(init=False)
+ rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
layer_types: Tuple[str, ...] = attr.ib(init=False)
layers: nn.ModuleList = attr.ib(init=False)
attn: partial = attr.ib(init=False)
@@ -40,12 +39,11 @@ class AttentionLayers(nn.Module):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.has_pos_emb = True if self.rotary_emb is not None else False
self.layer_types = self._get_layer_types() * self.depth
attn = load_partial_fn(
self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
)
- norm = load_partial_fn(self.norm_fn, dim=self.dim)
+ norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
self.layers = self._build_network(attn, norm, ff)
@@ -103,13 +101,11 @@ class AttentionLayers(nn.Module):
return x
+@attr.s(auto_attribs=True)
class Encoder(AttentionLayers):
- def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on encoder"
- super().__init__(causal=False, **kwargs)
+ causal: bool = attr.ib(default=False, init=False)
+@attr.s(auto_attribs=True)
class Decoder(AttentionLayers):
- def __init__(self, **kwargs: Any) -> None:
- assert "causal" not in kwargs, "Cannot set causality on decoder"
- super().__init__(causal=True, **kwargs)
+ causal: bool = attr.ib(default=True, init=False)
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 8bc3221..4930adf 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -12,9 +12,9 @@ from torch import Tensor
class ScaleNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+ def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
super().__init__()
- self.scale = dim ** -0.5
+ self.scale = normalized_shape ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -24,9 +24,9 @@ class ScaleNorm(nn.Module):
class PreNorm(nn.Module):
- def __init__(self, dim: int, fn: Type[nn.Module]) -> None:
+ def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
super().__init__()
- self.norm = nn.LayerNorm(dim)
+ self.norm = nn.LayerNorm(normalized_shape)
self.fn = fn
def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: