path: root/text_recognizer/networks/transformer/positional_encoding.py
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encoding.py')
-rw-r--r--  text_recognizer/networks/transformer/positional_encoding.py  16
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d03f630..d67d297 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module):
self.dropout = nn.Dropout(p=dropout_rate)
pe = self.make_pe(hidden_dim, max_len)
self.register_buffer("pe", pe)
-
+
@staticmethod
def make_pe(hidden_dim: int, max_len: int) -> Tensor:
"""Returns positional encoding."""
@@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module):
class PositionalEncoding2D(nn.Module):
"""Positional encodings for feature maps."""
- def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None:
+ def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
super().__init__()
if hidden_dim % 2 != 0:
raise ValueError(f"Embedding depth {hidden_dim} is not even!")
@@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module):
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
"""Returns 2d postional encoding."""
- pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [H, 1, D // 2]
+ pe_h = PositionalEncoding.make_pe(
+ hidden_dim // 2, max_len=max_h
+ ) # [H, 1, D // 2]
pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
- pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [W, 1, D // 2]
+ pe_w = PositionalEncoding.make_pe(
+ hidden_dim // 2, max_len=max_h
+ ) # [W, 1, D // 2]
pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W]
@@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module):
# Assumes x has shape [B, D, H, W]
if x.shape[1] != self.pe.shape[0]:
raise ValueError("Hidden dimensions does not match.")
- x += self.pe[:, :x.shape[2], :x.shape[3]]
+ x += self.pe[:, : x.shape[2], : x.shape[3]]
return x
-
-
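
For context, below is a minimal, self-contained sketch of the module this commit touches, reconstructed from the hunks above. The 1-D make_pe body, the forward methods, and the constructor defaults are assumptions based on the standard sinusoidal encoding (the diff only shows fragments), not a verbatim copy of the repository. One thing worth flagging: the diff builds pe_w with max_len=max_h even though its shape comment reads [W, 1, D // 2]; the sketch passes max_w, which is what that comment implies.

import math

import torch
from einops import repeat
from torch import Tensor, nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequences."""

    def __init__(self, hidden_dim: int, dropout_rate: float = 0.1, max_len: int = 1000) -> None:
        # Defaults are assumptions; the diff only shows these names being used.
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        pe = self.make_pe(hidden_dim, max_len)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
        """Returns positional encoding of shape [max_len, 1, hidden_dim]."""
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even channels
        pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
        return pe.unsqueeze(1)

    def forward(self, x: Tensor) -> Tensor:
        """Adds the encoding to x, assumed to have shape [T, B, D]."""
        x = x + self.pe[: x.shape[0]]
        return self.dropout(x)


class PositionalEncoding2D(nn.Module):
    """Positional encodings for feature maps."""

    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
        super().__init__()
        if hidden_dim % 2 != 0:
            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
        pe = self.make_pe(hidden_dim, max_h, max_w)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
        """Returns 2d positional encoding of shape [hidden_dim, max_h, max_w]."""
        # Half the channels encode the row index, tiled across all columns.
        pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [H, 1, D // 2]
        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)  # [D // 2, H, W]

        # The other half encode the column index, tiled across all rows.
        # Assumption: max_len=max_w (the diff passes max_h) and a repeat
        # pattern that actually yields [D // 2, H, W].
        pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_w)  # [W, 1, D // 2]
        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)  # [D // 2, H, W]

        return torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]

    def forward(self, x: Tensor) -> Tensor:
        """Adds the encoding to x, assumed to have shape [B, D, H, W]."""
        if x.shape[1] != self.pe.shape[0]:
            raise ValueError("Hidden dimensions do not match.")
        x += self.pe[:, : x.shape[2], : x.shape[3]]
        return x

Assuming those signatures, PositionalEncoding2D(hidden_dim=128, max_h=64, max_w=64)(torch.zeros(1, 128, 32, 32)) adds the top-left 32x32 slice of the precomputed [128, 64, 64] buffer to the feature map. Splitting the channels so half encode the row and half the column is what lets the 1-D make_pe be reused for both axes.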