summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encoding.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encoding.py')
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 5874e97..c50afc3 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
+ # [T, B, D]
+ if x.shape[2] != self.pe.shape[2]:
+ raise ValueError(f"x shape does not match pe in the 3rd dim.")
+ x = x + self.pe[: x.shape[0]]
return self.dropout(x)
@@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module):
pe = self.make_pe(hidden_dim, max_h, max_w)
self.register_buffer("pe", pe)
+ @staticmethod
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
"""Returns 2d postional encoding."""
pe_h = PositionalEncoding.make_pe(