diff options
Diffstat (limited to 'text_recognizer')
20 files changed, 247 insertions, 195 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 2380660..0a30a42 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -19,18 +19,10 @@ class IAMExtendedParagraphs(BaseDataModule): super().__init__(batch_size, num_workers) self.iam_paragraphs = IAMParagraphs( - batch_size, - num_workers, - train_fraction, - augment, - word_pieces, + batch_size, num_workers, train_fraction, augment, word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, - num_workers, - train_fraction, - augment, - word_pieces, + batch_size, num_workers, train_fraction, augment, word_pieces, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 62c44f9..24409bc 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -101,7 +101,7 @@ class IAMParagraphs(BaseDataModule): data, targets, transform=get_transform(image_shape=self.dims[1:], augment=augment), - target_transform=get_target_transform(self.word_pieces) + target_transform=get_target_transform(self.word_pieces), ) logger.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -162,10 +162,7 @@ def get_dataset_properties() -> Dict: "min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines")), }, - "crop_shape": { - "min": crop_shapes.min(axis=0), - "max": crop_shapes.max(axis=0), - }, + "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, "aspect_ratio": { "min": aspect_ratio.min(axis=0), "max": aspect_ratio.max(axis=0), @@ -286,9 +283,7 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com ), transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( - degrees=1, - shear=(-10, 10), - interpolation=InterpolationMode.BILINEAR, + degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, ), ] else: @@ -296,10 +291,12 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Com transforms_list.append(transforms.ToTensor()) return transforms.Compose(transforms_list) + def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: """Transform emnist characters to word pieces.""" return transforms.Compose([WordPiece()]) if word_pieces else None + def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 4ccc5c2..78e6c05 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -97,7 +97,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): transform=get_transform( image_shape=self.dims[1:], augment=self.augment ), - target_transform=get_target_transform(self.word_pieces) + target_transform=get_target_transform(self.word_pieces), ) def __repr__(self) -> str: diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index f4016ba..190febe 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -58,13 +58,13 @@ class WordPieceMapping(EmnistMapping): def __init__( self, num_features: int = 1000, - tokens: str = "iamdb_1kwp_tokens_1000.txt" , + tokens: str = "iamdb_1kwp_tokens_1000.txt", lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), - extra_symbols: Optional[Sequence[str]] = ("\n", ), + extra_symbols: Optional[Sequence[str]] = ("\n",), ) -> None: super().__init__(extra_symbols) self.wordpiece_processor = self._configure_wordpiece_processor( @@ -90,7 +90,13 @@ class WordPieceMapping(EmnistMapping): extra_symbols: Optional[Sequence[str]], ) -> Preprocessor: data_dir = ( - (Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb") + ( + Path(__file__).resolve().parents[2] + / "data" + / "downloaded" + / "iam" + / "iamdb" + ) if data_dir is None else Path(data_dir) ) diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 8d1bedd..d0f1f35 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -13,7 +13,7 @@ class WordPiece: def __init__( self, num_features: int = 1000, - tokens: str = "iamdb_1kwp_tokens_1000.txt" , + tokens: str = "iamdb_1kwp_tokens_1000.txt", lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, @@ -35,4 +35,4 @@ class WordPiece: self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len] + return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len] diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7dc1352..8dd4db2 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -39,7 +39,7 @@ class LitTransformerModel(LitBaseModel): def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: """Configure mapping.""" # TODO: Fix me!!! - mapping, inverse_mapping, _ = emnist_mapping() + mapping, inverse_mapping, _ = emnist_mapping(["\n"]) start_index = inverse_mapping["<s>"] end_index = inverse_mapping["<e>"] pad_index = inverse_mapping["<p>"] diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 41fd43f..63b43b2 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,2 +1,4 @@ """Network modules""" +from .backbones import EfficientNet from .vqvae import VQVAE +from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/backbones/__init__.py new file mode 100644 index 0000000..25aed0e --- /dev/null +++ b/text_recognizer/networks/backbones/__init__.py @@ -0,0 +1,2 @@ +"""Vision backbones.""" +from .efficientnet import EfficientNet diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/backbones/efficientnet.py new file mode 100644 index 0000000..61dea77 --- /dev/null +++ b/text_recognizer/networks/backbones/efficientnet.py @@ -0,0 +1,145 @@ +"""Efficient net b0 implementation.""" +import torch +from torch import nn +from torch import Tensor + + +class ConvNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + groups: int = 1, + ) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + nn.SiLU(inplace=True), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_channels: int, reduce_dim: int) -> None: + super().__init__() + self.se = nn.Sequential( + nn.AdaptiveAvgPool2d(1), # [C, H, W] -> [C, 1, 1] + nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1), + nn.SiLU(), + nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, x: Tensor) -> Tensor: + return x * self.se(x) + + +class InvertedResidulaBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + expand_ratio: float, + reduction: int = 4, + survival_prob: float = 0.8, + ) -> None: + super().__init__() + self.survival_prob = survival_prob + self.use_residual = in_channels == out_channels and stride == 1 + hidden_dim = in_channels * expand_ratio + self.expand = in_channels != hidden_dim + reduce_dim = in_channels // reduction + + if self.expand: + self.expand_conv = ConvNorm( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1 + ) + + self.conv = nn.Sequential( + ConvNorm( + hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim + ), + SqueezeExcite(hidden_dim, reduce_dim), + nn.Conv2d( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels), + ) + + def stochastic_depth(self, x: Tensor) -> Tensor: + if not self.training: + return x + + binary_tensor = ( + torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob + ) + return torch.div(x, self.survival_prob) * binary_tensor + + def forward(self, x: Tensor) -> Tensor: + out = self.expand_conv(x) if self.expand else x + if self.use_residual: + return self.stochastic_depth(self.conv(out)) + x + return self.conv(out) + + +class EfficientNet(nn.Module): + """Efficient net b0 backbone.""" + + def __init__(self) -> None: + super().__init__() + self.base_model = [ + # expand_ratio, channels, repeats, stride, kernel_size + [1, 16, 1, 1, 3], + [6, 24, 2, 2, 3], + [6, 40, 2, 2, 5], + [6, 80, 3, 2, 3], + [6, 112, 3, 1, 5], + [6, 192, 4, 2, 5], + [6, 320, 1, 1, 3], + ] + + self.backbone = self._build_b0() + + def _build_b0(self) -> nn.Sequential: + in_channels = 32 + layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)] + + for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model: + for i in range(repeats): + layers.append( + InvertedResidulaBlock( + in_channels, + out_channels, + expand_ratio=expand_ratio, + stride=stride if i == 0 else 1, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + in_channels = out_channels + layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0)) + + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + return self.backbone(x) diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index e23a15d..d42c29d 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -33,8 +33,8 @@ NUM_WORD_PIECES = 1000 class CNNTransformer(nn.Module): def __init__( self, - input_shape: Sequence[int], - output_shape: Sequence[int], + input_dim: Sequence[int], + output_dims: Sequence[int], encoder: Union[DictConfig, Dict], vocab_size: Optional[int] = None, num_decoder_layers: int = 4, @@ -43,22 +43,29 @@ class CNNTransformer(nn.Module): expansion_dim: int = 1024, dropout_rate: float = 0.1, transformer_activation: str = "glu", + *args, + **kwargs, ) -> None: + super().__init__() self.vocab_size = ( NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size ) + self.pad_index = 3 # TODO: fix me self.hidden_dim = hidden_dim - self.max_output_length = output_shape[0] + self.max_output_length = output_dims[0] # Image backbone self.encoder = self._configure_encoder(encoder) + self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) self.feature_map_encoding = PositionalEncoding2D( - hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] + hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] ) # Target token embedding self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + self.trg_position_encoding = PositionalEncoding( + hidden_dim, dropout_rate, max_len=output_dims[0] + ) # Transformer decoder self.decoder = Decoder( @@ -86,24 +93,25 @@ class CNNTransformer(nn.Module): self.head.weight.data.uniform_(-0.1, 0.1) nn.init.kaiming_normal_( - self.feature_map_encoding.weight.data, + self.encoder_proj.weight.data, a=0, mode="fan_out", nonlinearity="relu", ) - if self.feature_map_encoding.bias is not None: + if self.encoder_proj.bias is not None: _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.feature_map_encoding.weight.data + self.encoder_proj.weight.data ) bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + nn.init.normal_(self.encoder_proj.bias, -bound, bound) @staticmethod def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: encoder = OmegaConf.create(encoder) + args = encoder.args or {} network_module = importlib.import_module("text_recognizer.networks") encoder_class = getattr(network_module, encoder.type) - return encoder_class(**encoder.args) + return encoder_class(**args) def encode(self, image: Tensor) -> Tensor: """Extracts image features with backbone. @@ -121,6 +129,7 @@ class CNNTransformer(nn.Module): """ # Extract image features. image_features = self.encoder(image) + image_features = self.encoder_proj(image_features) # Add 2d encoding to the feature maps. image_features = self.feature_map_encoding(image_features) @@ -133,11 +142,19 @@ class CNNTransformer(nn.Module): """Decodes image features with transformer decoder.""" trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) + trg = rearrange(trg, "b t d -> t b d") trg = self.trg_position_encoding(trg) + trg = rearrange(trg, "t b d -> b t d") out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) logits = self.head(out) return logits + def forward(self, image: Tensor, trg: Tensor) -> Tensor: + image_features = self.encode(image) + output = self.decode(image_features, trg) + output = rearrange(output, "b t c -> b c t") + return output + def predict(self, image: Tensor) -> Tensor: """Transcribes text in image(s).""" bsz = image.shape[0] diff --git a/text_recognizer/networks/coat/__init__.py b/text_recognizer/networks/coat/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/text_recognizer/networks/coat/__init__.py +++ /dev/null diff --git a/text_recognizer/networks/coat/factor_attention.py b/text_recognizer/networks/coat/factor_attention.py deleted file mode 100644 index f91c5ef..0000000 --- a/text_recognizer/networks/coat/factor_attention.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Factorized attention with convolutional relative positional encodings.""" -from torch import nn - - -class FactorAttention(nn.Module): - """Factorized attention with relative positional encodings.""" - def __init__(self, dim: int, num_heads: int) -> None: - pass - diff --git a/text_recognizer/networks/coat/patch_embedding.py b/text_recognizer/networks/coat/patch_embedding.py deleted file mode 100644 index 3b7b76a..0000000 --- a/text_recognizer/networks/coat/patch_embedding.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Patch embedding for images and feature maps.""" -from typing import Sequence, Tuple - -from einops import rearrange -from loguru import logger -from torch import nn -from torch import Tensor - - -class PatchEmbedding(nn.Module): - """Patch embedding of images.""" - - def __init__( - self, - image_shape: Sequence[int], - patch_size: int = 16, - in_channels: int = 1, - embedding_dim: int = 512, - ) -> None: - if image_shape[0] % patch_size == 0 and image_shape[1] % patch_size == 0: - logger.error( - f"Image shape {image_shape} not divisable by patch size {patch_size}" - ) - - self.patch_size = patch_size - self.embedding = nn.Conv2d( - in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size - ) - self.norm = nn.LayerNorm(embedding_dim) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]: - """Embeds image or feature maps with patch embedding.""" - _, _, h, w = x.shape - h_out, w_out = h // self.patch_size, w // self.patch_size - x = self.embedding(x) - x = rearrange(x, "b c h w -> b (h w) c") - x = self.norm(x) - return x, (h_out, w_out) diff --git a/text_recognizer/networks/coat/positional_encodings.py b/text_recognizer/networks/coat/positional_encodings.py deleted file mode 100644 index 925db04..0000000 --- a/text_recognizer/networks/coat/positional_encodings.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Positional encodings for input sequence to transformer.""" -from typing import Dict, Union, Tuple - -from einops import rearrange -from loguru import logger -import torch -from torch import nn -from torch import Tensor - - -class RelativeEncoding(nn.Module): - """Relative positional encoding.""" - def __init__(self, channels: int, heads: int, windows: Union[int, Dict[int, int]]) -> None: - super().__init__() - self.windows = {windows: heads} if isinstance(windows, int) else windows - self.heads = list(self.windows.values()) - self.channel_heads = [head * channels for head in self.heads] - self.convs = nn.ModuleList([ - nn.Conv2d(in_channels=head * channels, - out_channels=head * channels, - kernel_shape=window, - padding=window // 2, - dilation=1, - groups=head * channels, - ) for window, head in self.windows.items()]) - - def forward(self, q: Tensor, v: Tensor, shape: Tuple[int, int]) -> Tensor: - """Applies relative positional encoding.""" - b, heads, hw, c = q.shape - h, w = shape - if hw != h * w: - logger.exception(f"Query width {hw} neq to height x width {h * w}") - raise ValueError - - v = rearrange(v, "b heads (h w) c -> b (heads c) h w", h=h, w=w) - v = torch.split(v, self.channel_heads, dim=1) - v = [conv(x) for conv, x in zip(self.convs, v)] - v = torch.cat(v, dim=1) - v = rearrange(v, "b (heads c) h w -> b heads (h w) c", heads=heads) - - encoding = q * v - zeros = torch.zeros((b, heads, 1, c), dtype=q.dtype, layout=q.layout, device=q.device) - encoding = torch.cat((zeros, encoding), dim=2) - return encoding - - -class PositionalEncoding(nn.Module): - """Convolutional positional encoding.""" - def __init__(self, dim: int, k: int = 3) -> None: - super().__init__() - self.encode = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=k, stride=1, padding=k//2, groups=dim) - - def forward(self, x: Tensor, shape: Tuple[int, int]) -> Tensor: - """Applies convolutional encoding.""" - _, hw, _ = x.shape - h, w = shape - - if hw != h * w: - logger.exception(f"Query width {hw} neq to height x width {h * w}") - raise ValueError - - # Depthwise convolution. - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.encode(x) + x - x = rearrange(x, "b c h w -> b (h w) c") - return x - - - - - - - - - - diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index da7553d..c33f419 100644 --- a/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py @@ -20,11 +20,7 @@ class Conv2dAuto(nn.Conv2d): def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: """3x3 convolution with batch norm.""" - conv3x3 = partial( - Conv2dAuto, - kernel_size=3, - bias=False, - ) + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) return nn.Sequential( conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels), diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index b10f93a..d7e3d08 100644 --- a/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py @@ -392,12 +392,7 @@ def load_transducer_loss( transitions = gtn.load(str(processed_path / transitions)) preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, + data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, ) num_tokens = preprocessor.num_tokens diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index 5874e97..c50afc3 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module): def forward(self, x: Tensor) -> Tensor: """Encodes the tensor with a postional embedding.""" - x = x + self.pe[:, : x.shape[1]] + # [T, B, D] + if x.shape[2] != self.pe.shape[2]: + raise ValueError(f"x shape does not match pe in the 3rd dim.") + x = x + self.pe[: x.shape[0]] return self.dropout(x) @@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module): pe = self.make_pe(hidden_dim, max_h, max_w) self.register_buffer("pe", pe) + @staticmethod def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor: """Returns 2d postional encoding.""" pe_h = PositionalEncoding.make_pe( diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py new file mode 100644 index 0000000..5e80572 --- /dev/null +++ b/text_recognizer/networks/transformer/rotary_embedding.py @@ -0,0 +1,39 @@ +"""Roatary embedding. + +Stolen from lucidrains: + https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py + +Explanation of roatary: + https://blog.eleuther.ai/rotary-embeddings/ + +""" +from typing import Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb[None, :, :] + + +def rotate_half(x: Tensor) -> Tensor: + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 93a1e43..32de912 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,12 +44,7 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, ) def _build_decompression_block( @@ -78,9 +73,7 @@ class Decoder(nn.Module): ) if self.upsampling and i < len(self.upsampling): - modules.append( - nn.Upsample(size=self.upsampling[i]), - ) + modules.append(nn.Upsample(size=self.upsampling[i]),) if dropout is not None: modules.append(dropout) @@ -109,12 +102,7 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d( - self.embedding_dim, - channels[0], - kernel_size=1, - stride=1, - ) + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) ) # Bottleneck module. diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index b0cceed..65801df 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - dropout: Optional[Type[nn.Module]], + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ @@ -138,12 +135,7 @@ class Encoder(nn.Module): ) encoder.append( - nn.Conv2d( - channels[-1], - self.embedding_dim, - kernel_size=1, - stride=1, - ) + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) ) return nn.Sequential(*encoder) |