diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
commit | b44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch) | |
tree | 998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/criterion/n_layer_discriminator.py | |
parent | 3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff) |
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/criterion/n_layer_discriminator.py')
-rw-r--r-- | text_recognizer/criterion/n_layer_discriminator.py | 59 |
1 files changed, 0 insertions, 59 deletions
diff --git a/text_recognizer/criterion/n_layer_discriminator.py b/text_recognizer/criterion/n_layer_discriminator.py deleted file mode 100644 index a9f47f9..0000000 --- a/text_recognizer/criterion/n_layer_discriminator.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Pix2pix discriminator loss.""" -from torch import nn, Tensor - -from text_recognizer.networks.vqvae.norm import Normalize - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator loss in Pix2Pix.""" - - def __init__( - self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 - ) -> None: - super().__init__() - self.in_channels = in_channels - self.num_channels = num_channels - self.num_layers = num_layers - self.discriminator = self._build_discriminator() - - def _build_discriminator(self) -> nn.Sequential: - """Builds discriminator.""" - discriminator = [ - nn.Sigmoid(), - nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.num_channels, - kernel_size=4, - stride=2, - padding=1, - ), - nn.Mish(inplace=True), - ] - in_channels = self.num_channels - for n in range(1, self.num_layers): - discriminator += [ - nn.Conv2d( - in_channels=in_channels, - out_channels=in_channels * n, - kernel_size=4, - stride=2, - padding=1, - ), - # Normalize(num_channels=in_channels * n), - nn.Mish(inplace=True), - ] - in_channels *= n - - discriminator += [ - nn.Conv2d( - in_channels=self.num_channels * (self.num_layers - 1), - out_channels=1, - kernel_size=4, - padding=1, - ) - ] - return nn.Sequential(*discriminator) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass through discriminator.""" - return self.discriminator(x) |