From 31e9673eef3088f08e3ee6aef8b78abd701ca329 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 4 Apr 2021 16:05:13 +0200 Subject: Reformat test for CER --- .../networks/transformer/positional_encoding.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'text_recognizer/networks/transformer/positional_encoding.py') diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index d03f630..d67d297 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module): self.dropout = nn.Dropout(p=dropout_rate) pe = self.make_pe(hidden_dim, max_len) self.register_buffer("pe", pe) - + @staticmethod def make_pe(hidden_dim: int, max_len: int) -> Tensor: """Returns positional encoding.""" @@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module): class PositionalEncoding2D(nn.Module): """Positional encodings for feature maps.""" - def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None: + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: super().__init__() if hidden_dim % 2 != 0: raise ValueError(f"Embedding depth {hidden_dim} is not even!") @@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module): def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: """Returns 2d postional encoding.""" - pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [H, 1, D // 2] + pe_h = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [H, 1, D // 2] pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [W, 1, D // 2] + pe_w = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [W, 1, D // 2] pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] @@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module): # Assumes x hase shape [B, D, H, W] if x.shape[1] != self.pe.shape[0]: raise ValueError("Hidden dimensions does not match.") - x += self.pe[:, :x.shape[2], :x.shape[3]] + x += self.pe[:, : x.shape[2], : x.shape[3]] return x - - -- cgit v1.2.3-70-g09d2