diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:49:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-15 21:49:51 +0200 |
commit | b3fbfd72a8f647161685b28d20b4b61519d8a643 (patch) | |
tree | a5cac4e15186396aae35231d6d6fe266691b0186 /text_recognizer/network/transformer/encoder.py | |
parent | c7e5354ffa43eccfc4e411375ce2f531af7bbcff (diff) |
Update transformer model
Diffstat (limited to 'text_recognizer/network/transformer/encoder.py')
-rw-r--r-- | text_recognizer/network/transformer/encoder.py | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index 1728c61..ce30372 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -13,6 +13,8 @@ class Encoder(nn.Module): ff_mult: int, depth: int, dropout_rate: float = 0.0, + use_rotary_emb: bool = False, + one_kv_head: bool = False, ) -> None: super().__init__() self.norm = nn.LayerNorm(dim) @@ -27,6 +29,8 @@ class Encoder(nn.Module): dropout_rate=dropout_rate, use_flash=True, norm_context=False, + use_rotary_emb=use_rotary_emb, + one_kv_head=one_kv_head, ) for _ in range(depth) ] |