summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/embeddings
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings')
-rw-r--r--text_recognizer/networks/transformer/embeddings/fourier.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py
index ade589c..7843c60 100644
--- a/text_recognizer/networks/transformer/embeddings/fourier.py
+++ b/text_recognizer/networks/transformer/embeddings/fourier.py
@@ -8,12 +8,10 @@ from torch import Tensor
class PositionalEncoding(nn.Module):
"""Encodes a sense of distance or time for transformer networks."""
- def __init__(
- self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
- ) -> None:
+ def __init__(self, dim: int, dropout_rate: float, max_len: int = 1000) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
- pe = self.make_pe(hidden_dim, max_len)
+ pe = self.make_pe(dim, max_len)
self.register_buffer("pe", pe)
@staticmethod