summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-20 22:18:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-20 22:18:25 +0200
commitcbe968b84546976ebdd666b5adbaead09b5c0c01 (patch)
treec0b14b3835a1d3105ead17507b259cc533daa03c
parent9995922ff957ce424dca0655a01d8338a519aa86 (diff)
Remove resnet
-rw-r--r--text_recognizer/networks/encoders/residual_network.py310
1 files changed, 0 insertions, 310 deletions
diff --git a/text_recognizer/networks/encoders/residual_network.py b/text_recognizer/networks/encoders/residual_network.py
deleted file mode 100644
index c33f419..0000000
--- a/text_recognizer/networks/encoders/residual_network.py
+++ /dev/null
@@ -1,310 +0,0 @@
-"""Residual CNN."""
-from functools import partial
-from typing import Callable, Dict, List, Optional, Type, Union
-
-from einops.layers.torch import Rearrange, Reduce
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class Conv2dAuto(nn.Conv2d):
- """Convolution with auto padding based on kernel size."""
-
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
-
-
-def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
- """3x3 convolution with batch norm."""
- conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
- return nn.Sequential(
- conv3x3(in_channels, out_channels, *args, **kwargs),
- nn.BatchNorm2d(out_channels),
- )
-
-
-class IdentityBlock(nn.Module):
- """Residual with identity block."""
-
- def __init__(
- self, in_channels: int, out_channels: int, activation: str = "relu"
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.blocks = nn.Identity()
- self.activation_fn = activation_function(activation)
- self.shortcut = nn.Identity()
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- residual = x
- if self.apply_shortcut:
- residual = self.shortcut(x)
- x = self.blocks(x)
- x += residual
- x = self.activation_fn(x)
- return x
-
- @property
- def apply_shortcut(self) -> bool:
- """Check if shortcut should be applied."""
- return self.in_channels != self.out_channels
-
-
-class ResidualBlock(IdentityBlock):
- """Residual with nonlinear shortcut."""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- expansion: int = 1,
- downsampling: int = 1,
- *args,
- **kwargs
- ) -> None:
- """Short summary.
-
- Args:
- in_channels (int): Number of in channels.
- out_channels (int): umber of out channels.
- expansion (int): Expansion factor of the out channels. Defaults to 1.
- downsampling (int): Downsampling factor used in stride. Defaults to 1.
- *args (type): Extra arguments.
- **kwargs (type): Extra key value arguments.
-
- """
- super().__init__(in_channels, out_channels, *args, **kwargs)
- self.expansion = expansion
- self.downsampling = downsampling
-
- self.shortcut = (
- nn.Sequential(
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.expanded_channels,
- kernel_size=1,
- stride=self.downsampling,
- bias=False,
- ),
- nn.BatchNorm2d(self.expanded_channels),
- )
- if self.apply_shortcut
- else None
- )
-
- @property
- def expanded_channels(self) -> int:
- """Computes the expanded output channels."""
- return self.out_channels * self.expansion
-
- @property
- def apply_shortcut(self) -> bool:
- """Check if shortcut should be applied."""
- return self.in_channels != self.expanded_channels
-
-
-class BasicBlock(ResidualBlock):
- """Basic ResNet block."""
-
- expansion = 1
-
- def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
- super().__init__(in_channels, out_channels, *args, **kwargs)
- self.blocks = nn.Sequential(
- conv_bn(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- bias=False,
- stride=self.downsampling,
- ),
- self.activation_fn,
- conv_bn(
- in_channels=self.out_channels,
- out_channels=self.expanded_channels,
- bias=False,
- ),
- )
-
-
-class BottleNeckBlock(ResidualBlock):
- """Bottleneck block to increase depth while minimizing parameter size."""
-
- expansion = 4
-
- def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
- super().__init__(in_channels, out_channels, *args, **kwargs)
- self.blocks = nn.Sequential(
- conv_bn(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- kernel_size=1,
- ),
- self.activation_fn,
- conv_bn(
- in_channels=self.out_channels,
- out_channels=self.out_channels,
- kernel_size=3,
- stride=self.downsampling,
- ),
- self.activation_fn,
- conv_bn(
- in_channels=self.out_channels,
- out_channels=self.expanded_channels,
- kernel_size=1,
- ),
- )
-
-
-class ResidualLayer(nn.Module):
- """ResNet layer."""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- block: BasicBlock = BasicBlock,
- num_blocks: int = 1,
- *args,
- **kwargs
- ) -> None:
- super().__init__()
- downsampling = 2 if in_channels != out_channels else 1
- self.blocks = nn.Sequential(
- block(
- in_channels, out_channels, *args, **kwargs, downsampling=downsampling
- ),
- *[
- block(
- out_channels * block.expansion,
- out_channels,
- downsampling=1,
- *args,
- **kwargs
- )
- for _ in range(num_blocks - 1)
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- x = self.blocks(x)
- return x
-
-
-class ResidualNetworkEncoder(nn.Module):
- """Encoder network."""
-
- def __init__(
- self,
- in_channels: int = 1,
- block_sizes: Union[int, List[int]] = (32, 64),
- depths: Union[int, List[int]] = (2, 2),
- activation: str = "relu",
- block: Type[nn.Module] = BasicBlock,
- levels: int = 1,
- *args,
- **kwargs
- ) -> None:
- super().__init__()
- self.block_sizes = (
- block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
- )
- self.depths = depths if isinstance(depths, list) else [depths] * levels
- self.activation = activation
- self.gate = nn.Sequential(
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=self.block_sizes[0],
- kernel_size=7,
- stride=2,
- padding=1,
- bias=False,
- ),
- nn.BatchNorm2d(self.block_sizes[0]),
- activation_function(self.activation),
- # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
- )
-
- self.blocks = self._configure_blocks(block)
-
- def _configure_blocks(
- self, block: Type[nn.Module], *args, **kwargs
- ) -> nn.Sequential:
- channels = [self.block_sizes[0]] + list(
- zip(self.block_sizes, self.block_sizes[1:])
- )
- blocks = [
- ResidualLayer(
- in_channels=channels[0],
- out_channels=channels[0],
- num_blocks=self.depths[0],
- block=block,
- activation=self.activation,
- *args,
- **kwargs
- )
- ]
- blocks += [
- ResidualLayer(
- in_channels=in_channels * block.expansion,
- out_channels=out_channels,
- num_blocks=num_blocks,
- block=block,
- activation=self.activation,
- *args,
- **kwargs
- )
- for (in_channels, out_channels), num_blocks in zip(
- channels[1:], self.depths[1:]
- )
- ]
-
- return nn.Sequential(*blocks)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- # If batch dimenstion is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
- x = self.gate(x)
- x = self.blocks(x)
- return x
-
-
-class ResidualNetworkDecoder(nn.Module):
- """Classification head."""
-
- def __init__(self, in_features: int, num_classes: int = 80) -> None:
- super().__init__()
- self.decoder = nn.Sequential(
- Reduce("b c h w -> b c", "mean"),
- nn.Linear(in_features=in_features, out_features=num_classes),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- return self.decoder(x)
-
-
-class ResidualNetwork(nn.Module):
- """Full residual network."""
-
- def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
- super().__init__()
- self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
- self.decoder = ResidualNetworkDecoder(
- in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
- num_classes=num_classes,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- x = self.encoder(x)
- x = self.decoder(x)
- return x