diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
commit | d3afa310f77f47553586eeee58e3d3345a754e2c (patch) | |
tree | 08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /text_recognizer/networks/vqvae/encoder.py | |
parent | 65d5f6c694e73792e40ed693a1381a792da8d277 (diff) |
New VQVAE
Diffstat (limited to 'text_recognizer/networks/vqvae/encoder.py')
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 176 |
1 files changed, 52 insertions, 124 deletions
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index 65801df..e480545 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -1,147 +1,75 @@ """CNN encoder for the VQ-VAE.""" from typing import Sequence, Optional, Tuple, Type -import torch +import attr from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer - - -class _ResidualBlock(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], - ) -> None: - super().__init__() - self.block = [ - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), - ] - - if dropout is not None: - self.block.append(dropout) - - self.block = nn.Sequential(*self.block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the residual forward pass.""" - return x + self.block(x) +from text_recognizer.networks.vqvae.residual import Residual +@attr.s(eq=False) class Encoder(nn.Module): """A CNN encoder network.""" - def __init__( - self, - in_channels: int, - channels: Sequence[int], - kernel_sizes: Sequence[int], - strides: Sequence[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.beta = beta - activation = activation_function(activation) - - # Configure encoder. - self.encoder = self._build_encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, - ) + in_channels: int = attr.ib() + out_channels: int = attr.ib() + res_channels: int = attr.ib() + num_residual_layers: int = attr.ib() + embedding_dim: int = attr.ib() + activation: str = attr.ib() + encoder: nn.Sequential = attr.ib(init=False) - # Configure Vector Quantizer. - self.vector_quantizer = VectorQuantizer( - self.num_embeddings, self.embedding_dim, self.beta - ) - - @staticmethod - def _build_compression_block( - in_channels: int, - channels: int, - kernel_sizes: Sequence[int], - strides: Sequence[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for out_channels, kernel_size, stride in configuration: - modules.append( - nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ), - activation, - ) - ) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - return modules + def __attrs_pre_init__(self) -> None: + super().__init__() - def _build_encoder( - self, - in_channels: int, - channels: int, - kernel_sizes: Sequence[int], - strides: Sequence[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - encoder = nn.ModuleList([]) + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.encoder = self._build_compression_block() + + def _build_compression_block(self) -> nn.Sequential: + activation_fn = activation_function(self.activation) + block = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels // 2, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.Conv2d( + in_channels=self.out_channels // 2, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + ), + ] - # compression module - encoder.extend( - self._build_compression_block( - in_channels, channels, kernel_sizes, strides, activation, dropout + for _ in range(self.num_residual_layers): + block.append( + Residual(in_channels=self.out_channels, out_channels=self.res_channels) ) - ) - # Bottleneck module. - encoder.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[-1], channels[-1], dropout) - for i in range(num_residual_layers) - ] + block.append( + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.embedding_dim, + kernel_size=1, ) ) - encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) - ) - - return nn.Sequential(*encoder) + return nn.Sequential(*block) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes input into a discrete representation.""" - z_e = self.encoder(x) - z_q, vq_loss = self.vector_quantizer(z_e) - return z_q, vq_loss + return self.encoder(x) |