summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-23 14:55:31 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-23 14:55:31 +0200
commita1d795bf02d14befc62cf600fb48842958148eba (patch)
tree21465c20262b15654985368731e8a289562e8df7 /text_recognizer
parentd20802e1f412045f7afa4bd8ac50be3488945e90 (diff)
Complete cnn-transformer network, not tested
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/iam_preprocessor.py2
-rw-r--r--text_recognizer/networks/cnn_tranformer.py81
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py2
3 files changed, 67 insertions, 18 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 506036e..f7457e4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,6 +47,8 @@ def load_metadata(
class Preprocessor:
"""A preprocessor for the IAM dataset."""
+ # TODO: attrs
+
def __init__(
self,
data_dir: Union[str, Path],
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index 5c13e9a..e030cb8 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -3,6 +3,7 @@ import math
from typing import Tuple, Type
import attr
+import torch
from torch import nn, Tensor
from text_recognizer.data.mappings import AbstractMapping
@@ -18,13 +19,19 @@ class CnnTransformer(nn.Module):
def __attrs_pre_init__(self) -> None:
super().__init__()
- # Parameters,
+ # Parameters and placeholders,
input_dims: Tuple[int, int, int] = attr.ib()
hidden_dim: int = attr.ib()
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
padding_idx: int = attr.ib()
+ start_token: str = attr.ib()
+ start_index: int = attr.ib(init=False, default=None)
+ end_token: str = attr.ib()
+ end_index: int = attr.ib(init=False, default=None)
+ pad_token: str = attr.ib()
+ pad_index: int = attr.ib(init=False, default=None)
# Modules.
encoder: Type[nn.Module] = attr.ib()
@@ -38,6 +45,9 @@ class CnnTransformer(nn.Module):
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ self.start_index = int(self.mapping.get_index(self.start_token))
+ self.end_index = int(self.mapping.get_index(self.end_token))
+ self.pad_index = int(self.mapping.get_index(self.pad_token))
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -99,20 +109,20 @@ class CnnTransformer(nn.Module):
z = self.encoder(x)
z = self.latent_encoder(z)
- # Permute tensor from [B, E, Ho * Wo] to [Sx, B, E]
- z = z.permute(2, 0, 1)
+ # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
+ z = z.permute(0, 2, 1)
return z
- def decode(self, z: Tensor, trg: Tensor) -> Tensor:
+ def decode(self, z: Tensor, context: Tensor) -> Tensor:
"""Decodes latent images embedding into word pieces.
Args:
z (Tensor): Latent images embedding.
- trg (Tensor): Word embeddings.
+ context (Tensor): Word embeddings.
Shapes:
- z: :math: `(B, Sx, E)`
- - trg: :math: `(B, Sy)`
+ - context: :math: `(B, Sy)`
- out: :math: `(B, Sy, T)`
where Sy is the length of the output and T is the number of tokens.
@@ -120,32 +130,69 @@ class CnnTransformer(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
- trg_mask = trg != self.padding_idx
- trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = self.token_pos_encoder(trg)
- out = self.decoder(x=trg, context=z, mask=trg_mask)
+ context_mask = context != self.padding_idx
+ context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
+ context = self.token_pos_encoder(context)
+ out = self.decoder(x=context, context=z, mask=context_mask)
logits = self.head(out)
return logits
- def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ def forward(self, x: Tensor, context: Tensor) -> Tensor:
"""Encodes images into word piece logtis.
Args:
x (Tensor): Input image(s).
- trg (Tensor): Target word embeddings.
+ context (Tensor): Target word embeddings.
Shapes:
- x: :math: `(B, C, H, W)`
- - trg: :math: `(B, Sy, T)`
+ - context: :math: `(B, Sy, T)`
where B is the batch size, C is the number of input channels, H is
the image height and W is the image width.
+
+ Returns:
+ Tensor: Sequence of logits.
"""
z = self.encode(x)
- logits = self.decode(z, trg)
+ logits = self.decode(z, context)
return logits
def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image."""
- # TODO: continue here!!!!!!!!!
- pass
+ """Predicts text in image.
+
+ Args:
+ x (Tensor): Image(s) to extract text from.
+
+ Shapes:
+ - x: :math: `(B, H, W)`
+ - output: :math: `(B, S)`
+
+ Returns:
+ Tensor: A tensor of token indices of the predictions from the model.
+ """
+ bsz = x.shape[0]
+
+ # Encode image(s) to latent vectors.
+ z = self.encode(x)
+
+ # Create a placeholder matrix for storing outputs from the network
+ output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+ output[:, 0] = self.start_index
+
+ for i in range(1, self.max_output_len):
+ context = output[:, :i] # (bsz, i)
+ logits = self.decode(z, context) # (i, bsz, c)
+ tokens = torch.argmax(logits, dim=-1) # (i, bsz)
+ output[:, i : i + 1] = tokens[-1:]
+
+ # Early stopping of prediction loop if token is end or padding token.
+ if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all():
+ break
+
+ # Set all tokens after end token to pad token.
+ for i in range(1, self.max_output_len):
+ idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index)
+ output[idx, i] = self.pad_index
+
+ return output
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 59598b5..6719efb 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -41,7 +41,7 @@ class EfficientNet(nn.Module):
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self._conv_stem: nn.Sequential = None
- self._blocks: nn.Sequential = None
+ self._blocks: nn.ModuleList = None
self._conv_head: nn.Sequential = None
self._build()