diff options
Diffstat (limited to 'text_recognizer/network/transformer/encoder.py')
-rw-r--r-- | text_recognizer/network/transformer/encoder.py | 4 |
1 file changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index 1728c61..ce30372 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -13,6 +13,8 @@ class Encoder(nn.Module):
         ff_mult: int,
         depth: int,
         dropout_rate: float = 0.0,
+        use_rotary_emb: bool = False,
+        one_kv_head: bool = False,
     ) -> None:
         super().__init__()
         self.norm = nn.LayerNorm(dim)
@@ -27,6 +29,8 @@ class Encoder(nn.Module):
                     dropout_rate=dropout_rate,
                     use_flash=True,
                     norm_context=False,
+                    use_rotary_emb=use_rotary_emb,
+                    one_kv_head=one_kv_head,
                 )
                 for _ in range(depth)
             ]