summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
commit1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/networks/vqvae
parentffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r--text_recognizer/networks/vqvae/decoder.py20
-rw-r--r--text_recognizer/networks/vqvae/encoder.py30
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py5
3 files changed, 38 insertions, 17 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..93a1e43 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
)
def _build_decompression_block(
@@ -72,8 +77,10 @@ class Decoder(nn.Module):
)
)
- if i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ if self.upsampling and i < len(self.upsampling):
+ modules.append(
+ nn.Upsample(size=self.upsampling[i]),
+ )
if dropout is not None:
modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ nn.Conv2d(
+ self.embedding_dim,
+ channels[0],
+ kernel_size=1,
+ stride=1,
+ )
)
# Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..b0cceed 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,5 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Sequence, Optional, Tuple, Type
import torch
from torch import nn
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -36,9 +39,9 @@ class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
+ channels: Sequence[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
@@ -77,12 +80,12 @@ class Encoder(nn.Module):
self.num_embeddings, self.embedding_dim, self.beta
)
+ @staticmethod
def _build_compression_block(
- self,
in_channels: int,
channels: int,
- kernel_sizes: List[int],
- strides: List[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
activation: Type[nn.Module],
dropout: Optional[Type[nn.Module]],
) -> nn.ModuleList:
@@ -109,8 +112,8 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: int,
- kernel_sizes: List[int],
- strides: List[int],
+ kernel_sizes: Sequence[int],
+ strides: Sequence[int],
num_residual_layers: int,
activation: Type[nn.Module],
dropout: Optional[Type[nn.Module]],
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ nn.Conv2d(
+ channels[-1],
+ self.embedding_dim,
+ kernel_size=1,
+ stride=1,
+ )
)
return nn.Sequential(*encoder)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 50448b4..1f08e5e 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,8 +1,7 @@
"""The VQ-VAE."""
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple
-import torch
from torch import nn
from torch import Tensor
@@ -25,6 +24,8 @@ class VQVAE(nn.Module):
beta: float = 0.25,
activation: str = "leaky_relu",
dropout_rate: float = 0.0,
+ *args: Any,
+ **kwargs: Dict,
) -> None:
super().__init__()