diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 |
commit | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch) | |
tree | 526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/text_recognizer/networks/cnn_transformer.py | |
parent | 5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff) |
Segmentation working!
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 19 |
1 files changed, 14 insertions, 5 deletions
```diff
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 16c7a41..b2b74b3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -88,10 +88,14 @@ class CNNTransformer(nn.Module):
         if len(src.shape) < 4:
             src = src[(None,) * (4 - len(src.shape))]
         src = self.backbone(src)
-        src = rearrange(src, "b c h w -> b w c h")
+
         if self.adaptive_pool is not None:
+            src = rearrange(src, "b c h w -> b w c h")
             src = self.adaptive_pool(src)
-        src = src.squeeze(3)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
         src = self.position_encoding(src)
         return src
 
@@ -110,12 +114,17 @@ class CNNTransformer(nn.Module):
         trg = self.position_encoding(trg)
         return trg
 
-    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
-        """Forward pass with CNN transfomer."""
-        h = self.extract_image_features(x)
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
         out = self.transformer(h, trg, trg_mask=trg_mask)
 
         logits = self.head(out)
         return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with CNN transfomer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits
```