diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/embeddings/fourier.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py index ade589c..7843c60 100644 --- a/text_recognizer/networks/transformer/embeddings/fourier.py +++ b/text_recognizer/networks/transformer/embeddings/fourier.py @@ -8,12 +8,10 @@ from torch import Tensor class PositionalEncoding(nn.Module): """Encodes a sense of distance or time for transformer networks.""" - def __init__( - self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 - ) -> None: + def __init__(self, dim: int, dropout_rate: float, max_len: int = 1000) -> None: super().__init__() self.dropout = nn.Dropout(p=dropout_rate) - pe = self.make_pe(hidden_dim, max_len) + pe = self.make_pe(dim, max_len) self.register_buffer("pe", pe) @staticmethod |