diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
commit | 31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch) | |
tree | f529d975d18d718a5d646e93f746d8be6f2f5cfe /text_recognizer/networks | |
parent | 36964354407d0fdf73bdca2f611fee1664860197 (diff) |
Reformat test for CER
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index d03f630..d67d297 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module): self.dropout = nn.Dropout(p=dropout_rate) pe = self.make_pe(hidden_dim, max_len) self.register_buffer("pe", pe) - + @staticmethod def make_pe(hidden_dim: int, max_len: int) -> Tensor: """Returns positional encoding.""" @@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module): class PositionalEncoding2D(nn.Module): """Positional encodings for feature maps.""" - def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None: + def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None: super().__init__() if hidden_dim % 2 != 0: raise ValueError(f"Embedding depth {hidden_dim} is not even!") @@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module): def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: """Returns 2d postional encoding.""" - pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [H, 1, D // 2] + pe_h = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [H, 1, D // 2] pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) - pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h) # [W, 1, D // 2] + pe_w = PositionalEncoding.make_pe( + hidden_dim // 2, max_len=max_h + ) # [W, 1, D // 2] pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] @@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module): # Assumes x hase shape [B, D, H, W] if x.shape[1] != self.pe.shape[0]: raise ValueError("Hidden dimensions does not match.") - x += self.pe[:, :x.shape[2], :x.shape[3]] + x += self.pe[:, : x.shape[2], : x.shape[3]] return x - - |