diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/layers.py | 16 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/norm.py | 8 |
3 files changed, 11 insertions, 15 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 2770dc1..9202cce 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -24,9 +24,9 @@ class Attention(nn.Module): dim: int = attr.ib() num_heads: int = attr.ib() + causal: bool = attr.ib(default=False) dim_head: int = attr.ib(default=64) dropout_rate: float = attr.ib(default=0.0) - casual: bool = attr.ib(default=False) scale: float = attr.ib(init=False) dropout: nn.Dropout = attr.ib(init=False) fc: nn.Linear = attr.ib(init=False) diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 9b2f236..66c9c50 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -30,8 +30,7 @@ class AttentionLayers(nn.Module): causal: bool = attr.ib(default=False) cross_attend: bool = attr.ib(default=False) pre_norm: bool = attr.ib(default=True) - rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False) - has_pos_emb: bool = attr.ib(init=False) + rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None) layer_types: Tuple[str, ...] = attr.ib(init=False) layers: nn.ModuleList = attr.ib(init=False) attn: partial = attr.ib(init=False) @@ -40,12 +39,11 @@ class AttentionLayers(nn.Module): def __attrs_post_init__(self) -> None: """Post init configuration.""" - self.has_pos_emb = True if self.rotary_emb is not None else False self.layer_types = self._get_layer_types() * self.depth attn = load_partial_fn( self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs ) - norm = load_partial_fn(self.norm_fn, dim=self.dim) + norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim) ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs) self.layers = self._build_network(attn, norm, ff) @@ -103,13 +101,11 @@ class AttentionLayers(nn.Module): return x +@attr.s(auto_attribs=True) class Encoder(AttentionLayers): - def __init__(self, **kwargs: Any) -> None: - assert "causal" not in kwargs, "Cannot set causality on encoder" - super().__init__(causal=False, **kwargs) + causal: bool = attr.ib(default=False, init=False) +@attr.s(auto_attribs=True) class Decoder(AttentionLayers): - def __init__(self, **kwargs: Any) -> None: - assert "causal" not in kwargs, "Cannot set causality on decoder" - super().__init__(causal=True, **kwargs) + causal: bool = attr.ib(default=True, init=False) diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 8bc3221..4930adf 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -12,9 +12,9 @@ from torch import Tensor class ScaleNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1.0e-5) -> None: + def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None: super().__init__() - self.scale = dim ** -0.5 + self.scale = normalized_shape ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) @@ -24,9 +24,9 @@ class ScaleNorm(nn.Module): class PreNorm(nn.Module): - def __init__(self, dim: int, fn: Type[nn.Module]) -> None: + def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None: super().__init__() - self.norm = nn.LayerNorm(dim) + self.norm = nn.LayerNorm(normalized_shape) self.fn = fn def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: |