summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-13 18:46:25 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-09-13 18:46:25 +0200
commit 764a14468ab8b669dd0906641b891b9bdec7b96a (patch)
tree   995605e68ea1aa5d77f9cbb19580be0ae359e977 /text_recognizer/networks/transformer
parent 143d37636c4533a74c558ca5afb8a579af38de97 (diff)
Rename hidden dim to dim
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--  text_recognizer/networks/transformer/embeddings/fourier.py | 6
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/text_recognizer/networks/transformer/embeddings/fourier.py b/text_recognizer/networks/transformer/embeddings/fourier.py
index ade589c..7843c60 100644
--- a/text_recognizer/networks/transformer/embeddings/fourier.py
+++ b/text_recognizer/networks/transformer/embeddings/fourier.py
@@ -8,12 +8,10 @@ from torch import Tensor
class PositionalEncoding(nn.Module):
"""Encodes a sense of distance or time for transformer networks."""
- def __init__(
- self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
- ) -> None:
+ def __init__(self, dim: int, dropout_rate: float, max_len: int = 1000) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
- pe = self.make_pe(hidden_dim, max_len)
+ pe = self.make_pe(dim, max_len)
self.register_buffer("pe", pe)
@staticmethod