diff options
Diffstat (limited to 'text_recognizer/network/transformer/encoder.py')
| -rw-r--r-- | text_recognizer/network/transformer/encoder.py | 4 | 
1 files changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index 1728c61..ce30372 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -13,6 +13,8 @@ class Encoder(nn.Module):          ff_mult: int,          depth: int,          dropout_rate: float = 0.0, +        use_rotary_emb: bool = False, +        one_kv_head: bool = False,      ) -> None:          super().__init__()          self.norm = nn.LayerNorm(dim) @@ -27,6 +29,8 @@ class Encoder(nn.Module):                      dropout_rate=dropout_rate,                      use_flash=True,                      norm_context=False, +                    use_rotary_emb=use_rotary_emb, +                    one_kv_head=one_kv_head,                  )                  for _ in range(depth)              ]  |