Diffstat (limited to 'text_recognizer')
-rw-r--r--   text_recognizer/networks/transformer/layers.py   27
1 file changed, 14 insertions, 13 deletions
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index b132522..ca4569f 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -26,7 +26,7 @@ class AttentionLayers(nn.Module):
     norm_fn: str = attr.ib()
     ff_fn: str = attr.ib()
     ff_kwargs: DictConfig = attr.ib()
-    rotary_emb: Optional[RotaryEmbedding] = attr.ib()
+    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
     local_attn_fn: Optional[str] = attr.ib(default=None)
     local_attn_kwargs: Optional[DictConfig] = attr.ib(default=None)
     causal: bool = attr.ib(default=False)
@@ -35,16 +35,17 @@ class AttentionLayers(nn.Module):
     local_depth: Optional[int] = attr.ib(default=None)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
+    has_pos_emb: bool = attr.ib(init=False, default=False)

     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.layer_types = self._get_layer_types() * self.depth
-        self.layers = self._build_network()
-
         if self.local_attn_kwargs is not None and self.local_attn_fn is not None:
             if "depth" not in self.local_attn_kwargs:
                 ValueError("Local depth has to be specified")
             self.local_depth = self.local_attn_kwargs.pop("depth")
+        self.layer_types = self._get_layer_types() * self.depth
+        self.layers = self._build_network()
+        self.has_pos_emb = self.rotary_emb is None

     def _get_layer_types(self) -> Tuple:
         """Get layer specification."""
@@ -53,7 +54,7 @@ class AttentionLayers(nn.Module):
         return "a", "f"

     def _configure_causal_attn(self, i: int) -> Type[nn.Module]:
-        if self.local_depth is not None and i <= self.local_depth:
+        if self.local_depth is not None and i < self.local_depth:
             return load_partial_fn(
                 self.local_attn_fn,
                 dim=self.dim,
@@ -73,9 +74,11 @@ class AttentionLayers(nn.Module):
         norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
         ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
         layers = nn.ModuleList([])
-        for i, layer_type in enumerate(self.layer_types):
+        self_attn_depth = 0
+        for layer_type in self.layer_types:
             if layer_type == "a":
-                layer = self._configure_causal_attn(i)
+                layer = self._configure_causal_attn(self_attn_depth)
+                self_attn_depth += 1
             elif layer_type == "c":
                 layer = load_partial_fn(
                     self.attn_fn,
@@ -122,12 +125,10 @@ class AttentionLayers(nn.Module):
         return x


-@attr.s(auto_attribs=True, eq=False)
-class Encoder(AttentionLayers):
-    causal: bool = attr.ib(default=False, init=False)
-
-
 class Decoder(AttentionLayers):
+    """Decoder module."""
+
     def __init__(self, **kwargs: Any) -> None:
-        assert "causal" not in kwargs, "Cannot set causality on decoder"
+        if "causal" in kwargs:
+            ValueError("Cannot set causality on decoder")
         super().__init__(causal=True, **kwargs)
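The commit leans on standard attrs behaviour: giving rotary_emb a default makes it optional at construction time, while init=False fields such as layer_types, layers, and has_pos_emb are excluded from __init__ and filled in __attrs_post_init__, which now does the local-attention bookkeeping before building the layer stack. A minimal, self-contained sketch of that pattern (the class and field names here are hypothetical, not the repository's):

from typing import Optional, Tuple

import attr


@attr.s(auto_attribs=True, eq=False)
class LayerSpec:
    depth: int = attr.ib()                              # required argument
    rotary_dim: Optional[int] = attr.ib(default=None)   # optional once it has a default
    causal: bool = attr.ib(default=False)
    layer_types: Tuple[str, ...] = attr.ib(init=False)  # derived, not a constructor argument
    local_kwargs: Optional[dict] = attr.ib(default=None)
    local_depth: Optional[int] = attr.ib(default=None)

    def __attrs_post_init__(self) -> None:
        # Pop the configuration that feeds the derived fields first, then
        # build them, mirroring the reordered __attrs_post_init__ above.
        if self.local_kwargs is not None and "depth" in self.local_kwargs:
            self.local_depth = self.local_kwargs.pop("depth")
        self.layer_types = ("a", "f") * self.depth


spec = LayerSpec(depth=2, local_kwargs={"depth": 1})  # rotary_dim can be omitted
print(spec.layer_types)  # ('a', 'f', 'a', 'f')
print(spec.local_depth)  # 1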
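The _build_network change also alters which layers count toward local_depth: the index passed to _configure_causal_attn now advances only on self-attention ("a") layers, and the strict < comparison means exactly local_depth of them use the local attention function. A small standalone sketch of that selection logic (the helper name and return values are illustrative only):

from typing import List, Optional, Tuple


def select_attention_kinds(
    layer_types: Tuple[str, ...], local_depth: Optional[int]
) -> List[str]:
    kinds = []
    self_attn_depth = 0
    for layer_type in layer_types:
        if layer_type == "a":
            # Counter advances only on self-attention layers; strict comparison
            # makes the first `local_depth` of them local.
            is_local = local_depth is not None and self_attn_depth < local_depth
            kinds.append("local-attn" if is_local else "attn")
            self_attn_depth += 1
        else:
            kinds.append("ff")
    return kinds


# With three attention/feedforward pairs and local_depth=2, only the first two
# self-attention layers become local; "f" layers no longer advance the counter.
print(select_attention_kinds(("a", "f") * 3, local_depth=2))
# ['local-attn', 'ff', 'local-attn', 'ff', 'attn', 'ff']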
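One caveat applies to both new guards (the "Local depth has to be specified" check and the Decoder causality check): a bare ValueError(...) expression only constructs the exception object and discards it, so nothing is raised. A guard that actually rejects the argument needs raise, as in this hypothetical helper (a sketch, not the repository's code):

from typing import Any, Dict


def reject_causal_override(kwargs: Dict[str, Any]) -> None:
    """Reject a caller-supplied `causal` flag; without `raise` the check is a no-op."""
    if "causal" in kwargs:
        raise ValueError("Cannot set causality on decoder")


reject_causal_override({"dim": 256})          # passes silently
# reject_causal_override({"causal": False})   # raises ValueError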