Diffstat (limited to 'text_recognizer/networks/cnn_transformer.py')
-rw-r--r--  text_recognizer/networks/cnn_transformer.py  182
1 file changed, 0 insertions(+), 182 deletions(-)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
deleted file mode 100644
index 80798e1..0000000
--- a/text_recognizer/networks/cnn_transformer.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# """A Transformer with a cnn backbone.
-#
-# The network encodes a image with a convolutional backbone to a latent representation,
-# i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-# spatial information. The resulting feature are then set to a transformer decoder
-# together with the target tokens.
-#
-# TODO: Local attention for lower layer in attention.
-#
-# """
-# import importlib
-# import math
-# from typing import Dict, Optional, Union, Sequence, Type
-#
-# from einops import rearrange
-# from omegaconf import DictConfig, OmegaConf
-# import torch
-# from torch import nn
-# from torch import Tensor
-#
-# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-# from text_recognizer.networks.transformer import (
-#     Decoder,
-#     DecoderLayer,
-#     PositionalEncoding,
-#     PositionalEncoding2D,
-#     target_padding_mask,
-# )
-#
-# NUM_WORD_PIECES = 1000
-#
-#
-# class CNNTransformer(nn.Module):
-#     def __init__(
-#         self,
-#         input_dim: Sequence[int],
-#         output_dims: Sequence[int],
-#         encoder: Union[DictConfig, Dict],
-#         vocab_size: Optional[int] = None,
-#         num_decoder_layers: int = 4,
-#         hidden_dim: int = 256,
-#         num_heads: int = 4,
-#         expansion_dim: int = 1024,
-#         dropout_rate: float = 0.1,
-#         transformer_activation: str = "glu",
-#         *args,
-#         **kwargs,
-#     ) -> None:
-#         super().__init__()
-#         self.vocab_size = (
-#             NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
-#         )
-#         self.pad_index = 3  # TODO: fix me
-#         self.hidden_dim = hidden_dim
-#         self.max_output_length = output_dims[0]
-#
-#         # Image backbone
-#         self.encoder = self._configure_encoder(encoder)
-#         self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
-#         self.feature_map_encoding = PositionalEncoding2D(
-#             hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
-#         )
-#
-#         # Target token embedding
-#         self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-#         self.trg_position_encoding = PositionalEncoding(
-#             hidden_dim, dropout_rate, max_len=output_dims[0]
-#         )
-#
-#         # Transformer decoder
-#         self.decoder = Decoder(
-#             decoder_layer=DecoderLayer(
-#                 hidden_dim=hidden_dim,
-#                 num_heads=num_heads,
-#                 expansion_dim=expansion_dim,
-#                 dropout_rate=dropout_rate,
-#                 activation=transformer_activation,
-#             ),
-#             num_layers=num_decoder_layers,
-#             norm=nn.LayerNorm(hidden_dim),
-#         )
-#
-#         # Classification head
-#         self.head = nn.Linear(hidden_dim, self.vocab_size)
-#
-#         # Initialize weights
-#         self._init_weights()
-#
-#     def _init_weights(self) -> None:
-#         """Initialize network weights."""
-#         self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
-#         self.head.bias.data.zero_()
-#         self.head.weight.data.uniform_(-0.1, 0.1)
-#
-#         nn.init.kaiming_normal_(
-#             self.encoder_proj.weight.data,
-#             a=0,
-#             mode="fan_out",
-#             nonlinearity="relu",
-#         )
-#         if self.encoder_proj.bias is not None:
-#             _, fan_out = nn.init._calculate_fan_in_and_fan_out(
-#                 self.encoder_proj.weight.data
-#             )
-#             bound = 1 / math.sqrt(fan_out)
-#             nn.init.uniform_(self.encoder_proj.bias, -bound, bound)
-#
-#     @staticmethod
-#     def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
-#         encoder = OmegaConf.create(encoder)
-#         args = encoder.args or {}
-#         network_module = importlib.import_module("text_recognizer.networks")
-#         encoder_class = getattr(network_module, encoder.type)
-#         return encoder_class(**args)
-#
-#     def encode(self, image: Tensor) -> Tensor:
-#         """Extracts image features with the backbone.
-#
-#         Args:
-#             image (Tensor): Image(s) of handwritten text.
-#
-#         Returns:
-#             Tensor: Image features.
-#
-#         Shapes:
-#             - image: :math:`(B, C, H, W)`
-#             - latent: :math:`(B, T, C)`
-#
-#         """
-#         # Extract image features.
-#         image_features = self.encoder(image)
-#         image_features = self.encoder_proj(image_features)
-#
-#         # Add 2D encoding to the feature maps.
-#         image_features = self.feature_map_encoding(image_features)
-#
-#         # Collapse the feature maps' height and width into one dimension.
-#         image_features = rearrange(image_features, "b c h w -> b (h w) c")
-#         return image_features
-#
-#     def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
-#         """Decodes image features with the transformer decoder."""
-#         trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
-#         trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
-#         trg = rearrange(trg, "b t d -> t b d")
-#         trg = self.trg_position_encoding(trg)
-#         trg = rearrange(trg, "t b d -> b t d")
-#         out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
-#         logits = self.head(out)
-#         return logits
-#
-#     def forward(self, image: Tensor, trg: Tensor) -> Tensor:
-#         image_features = self.encode(image)
-#         output = self.decode(image_features, trg)
-#         output = rearrange(output, "b t c -> b c t")
-#         return output
-#
-#     def predict(self, image: Tensor) -> Tensor:
-#         """Transcribes text in image(s)."""
-#         bsz = image.shape[0]
-#         image_features = self.encode(image)
-#
-#         output_tokens = (
-#             (torch.ones((bsz, self.max_output_length)) * self.pad_index)
-#             .type_as(image)
-#             .long()
-#         )
-#         output_tokens[:, 0] = self.start_index
-#         for i in range(1, self.max_output_length):
-#             trg = output_tokens[:, :i]
-#             output = self.decode(image_features, trg)
-#             output = torch.argmax(output, dim=-1)
-#             output_tokens[:, i] = output[:, -1]
-#
-#         # Set all tokens after the end token to padding.
-#         for i in range(1, self.max_output_length):
-#             indices = (output_tokens[:, i - 1] == self.end_index) | (
-#                 output_tokens[:, i - 1] == self.pad_index
-#             )
-#             output_tokens[indices, i] = self.pad_index
-#
-#         return output_tokens
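
For reference on the removed code: `_configure_encoder` instantiated the backbone at runtime from an OmegaConf/dict config with `type` and `args` keys. Below is a minimal, standalone sketch of that pattern; the module path `my_project.networks` and the example config are placeholders, not names from this repository.

import importlib

from omegaconf import OmegaConf


def build_from_config(config: dict, module_path: str = "my_project.networks"):
    """Instantiate the class named in config["type"] with config["args"]."""
    cfg = OmegaConf.create(config)
    args = cfg.args or {}  # missing "args" resolves to None, fall back to no kwargs
    module = importlib.import_module(module_path)
    cls = getattr(module, cfg.type)
    return cls(**args)


# Hypothetical usage: encoder = build_from_config({"type": "ResNet", "args": {"depth": 18}})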
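The `encode` method above projects the backbone's feature maps to the decoder width and flattens the spatial grid into a sequence of memory tokens. Below is a self-contained sketch of that shape flow, assuming a stand-in convolutional backbone and omitting the 2D positional encoding (`PositionalEncoding2D` lives elsewhere in the repo); only the 1x1 projection and the `rearrange` collapse mirror the deleted code.

import torch
from einops import rearrange
from torch import nn

hidden_dim = 256
backbone = nn.Sequential(  # stand-in for the configurable encoder backbone
    nn.Conv2d(1, 256, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)
encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)  # 1x1 projection to the decoder width

image = torch.randn(2, 1, 56, 1024)       # (B, C, H, W) crop of a handwritten line
features = encoder_proj(backbone(image))  # (B, hidden_dim, H', W') feature maps
memory = rearrange(features, "b c h w -> b (h w) c")
print(memory.shape)                       # torch.Size([2, 14336, 256]), i.e. (B, H' * W', hidden_dim)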
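The removed `predict` method ran greedy autoregressive decoding over the encoded image features. A hedged, standalone version of that loop is sketched below against a generic `decode(memory, trg) -> (B, T, vocab)` callable; `start_index`, `end_index`, and `pad_index` are placeholder arguments, not values taken from this project.

from typing import Callable

import torch
from torch import Tensor


def greedy_decode(
    decode: Callable[[Tensor, Tensor], Tensor],
    memory: Tensor,
    max_output_length: int,
    start_index: int,
    end_index: int,
    pad_index: int,
) -> Tensor:
    """Greedily decode one token at a time from encoded image features."""
    bsz = memory.shape[0]
    tokens = torch.full(
        (bsz, max_output_length), pad_index, dtype=torch.long, device=memory.device
    )
    tokens[:, 0] = start_index
    for i in range(1, max_output_length):
        logits = decode(memory, tokens[:, :i])       # (B, i, vocab_size)
        tokens[:, i] = logits[:, -1].argmax(dim=-1)  # most likely token at the newest position
    # Force everything after the first end token (or padding) back to padding.
    for i in range(1, max_output_length):
        done = (tokens[:, i - 1] == end_index) | (tokens[:, i - 1] == pad_index)
        tokens[done, i] = pad_index
    return tokens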