summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-03 00:31:00 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-03 00:31:00 +0200
commit3a21c29e2eff4378c63717f8920ca3ccbfef013c (patch)
treeba46504d7baa8d4fb5bfd473acf99a7a184b330c /text_recognizer/networks
parent75eb34020620584247313926527019471411f6af (diff)
Lint files
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/conv_transformer.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 714792e..60c0ef8 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
"""Vision transformer for character recognition."""
import math
-from typing import Tuple
+from typing import Tuple, Type
from torch import nn, Tensor
@@ -24,6 +24,8 @@ class ConvTransformer(nn.Module):
pad_index: Tensor,
encoder: nn.Module,
decoder: Decoder,
+ pixel_pos_embedding: Type[nn.Module],
+ token_pos_embedding: Type[nn.Module],
) -> None:
super().__init__()
self.input_dims = input_dims
@@ -43,11 +45,7 @@ class ConvTransformer(nn.Module):
out_channels=self.hidden_dim,
kernel_size=1,
),
- PositionalEncoding2D(
- hidden_dim=self.hidden_dim,
- max_h=self.input_dims[1],
- max_w=self.input_dims[2],
- ),
+ pixel_pos_embedding,
nn.Flatten(start_dim=2),
)
@@ -57,9 +55,8 @@ class ConvTransformer(nn.Module):
)
# Positional encoding for decoder tokens.
- self.token_pos_encoder = PositionalEncoding(
- hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate
- )
+ self.token_pos_embedding = token_pos_embedding
+
# Head
self.head = nn.Linear(
in_features=self.hidden_dim, out_features=self.num_classes
@@ -119,7 +116,7 @@ class ConvTransformer(nn.Module):
trg = trg.long()
trg_mask = trg != self.pad_index
trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = self.token_pos_encoder(trg)
+ trg = self.token_pos_embedding(trg)
out = self.decoder(x=trg, context=src, mask=trg_mask)
logits = self.head(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]