summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/__init__.py4
-rw-r--r--src/text_recognizer/networks/beam.py83
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py19
-rw-r--r--src/text_recognizer/networks/fcn.py99
-rw-r--r--src/text_recognizer/networks/neural_machine_reader.py201
-rw-r--r--src/text_recognizer/networks/residual_network.py7
-rw-r--r--src/text_recognizer/networks/unet.py159
7 files changed, 230 insertions, 342 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index 1635039..f958672 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,11 +3,13 @@ from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
+from .fcn import FCN
from .lenet import LeNet
from .metrics import accuracy, accuracy_ignore_pad, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .transformer import Transformer
+from .unet import UNet
from .util import sliding_window
from .wide_resnet import WideResidualNetwork
@@ -18,12 +20,14 @@ __all__ = [
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
"DenseNet",
+ "FCN",
"greedy_decoder",
"MLP",
"LeNet",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
+ "UNet",
"Transformer",
"wer",
"WideResidualNetwork",
diff --git a/src/text_recognizer/networks/beam.py b/src/text_recognizer/networks/beam.py
new file mode 100644
index 0000000..dccccdb
--- /dev/null
+++ b/src/text_recognizer/networks/beam.py
@@ -0,0 +1,83 @@
+"""Implementation of beam search decoder for a sequence to sequence network.
+
+Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
+
+"""
+# from typing import List
+# from Queue import PriorityQueue
+
+# from loguru import logger
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+
+
+# class Node:
+# def __init__(
+# self, parent: Node, target_index: int, log_prob: Tensor, length: int
+# ) -> None:
+# self.parent = parent
+# self.target_index = target_index
+# self.log_prob = log_prob
+# self.length = length
+# self.reward = 0.0
+
+# def eval(self, alpha: float = 1.0) -> Tensor:
+# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
+
+
+# @torch.no_grad()
+# def beam_decoder(
+# network, mapper, device, memory: Tensor = None, max_len: int = 97,
+# ) -> Tensor:
+# beam_width = 10
+# topk = 1 # How many sentences to generate.
+
+# trg_indices = [mapper(mapper.init_token)]
+
+# end_nodes = []
+
+# node = Node(None, trg_indices, 0, 1)
+# nodes = PriorityQueue()
+
+# nodes.put((node.eval(), node))
+# q_size = 1
+
+# # Beam search
+# for _ in range(max_len):
+# if q_size > 2000:
+# logger.warning("Could not decoder input")
+# break
+
+# # Fetch the best node.
+# score, n = nodes.get()
+# decoder_input = n.target_index
+
+# if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
+# end_nodes.append((score, n))
+
+# # If we reached the maximum number of sentences required.
+# if len(end_nodes) >= 1:
+# break
+# else:
+# continue
+
+# # Forward pass with transformer.
+# trg = torch.tensor(trg_indices, device=device)[None, :].long()
+# trg = network.target_embedding(trg)
+# logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+# log_prob = F.log_softmax(logits, dim=2)
+
+# log_prob, indices = torch.topk(log_prob, beam_width)
+
+# for new_k in range(beam_width):
+# # TODO: continue from here
+# token_index = indices[0][new_k].view(1, -1)
+# log_p = log_prob[0][new_k].item()
+
+# node = Node()
+
+# pass
+
+# pass
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 16c7a41..b2b74b3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -88,10 +88,14 @@ class CNNTransformer(nn.Module):
if len(src.shape) < 4:
src = src[(None,) * (4 - len(src.shape))]
src = self.backbone(src)
- src = rearrange(src, "b c h w -> b w c h")
+
if self.adaptive_pool is not None:
+ src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
- src = src.squeeze(3)
+ src = src.squeeze(3)
+ else:
+ src = rearrange(src, "b c h w -> b (w h) c")
+
src = self.position_encoding(src)
return src
@@ -110,12 +114,17 @@ class CNNTransformer(nn.Module):
trg = self.position_encoding(trg)
return trg
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
+ def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
trg_mask = self._create_trg_mask(trg)
trg = self.target_embedding(trg)
out = self.transformer(h, trg, trg_mask=trg_mask)
logits = self.head(out)
return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ h = self.extract_image_features(x)
+ logits = self.decode_image_features(h, trg)
+ return logits
diff --git a/src/text_recognizer/networks/fcn.py b/src/text_recognizer/networks/fcn.py
deleted file mode 100644
index f9c4fd4..0000000
--- a/src/text_recognizer/networks/fcn.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""Fully Convolutional Network (FCN) with dilated kernels for global context."""
-from typing import List, Tuple, Type
-import torch
-from torch import nn
-from torch import Tensor
-
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DilatedBlock(nn.Module):
- def __init__(
- self,
- channels: List[int],
- kernel_sizes: List[int],
- dilations: List[int],
- paddings: List[int],
- activation_fn: Type[nn.Module],
- ) -> None:
- super().__init__()
- self.dilation_conv = nn.Sequential(
- nn.Conv2d(
- in_channels=channels[0],
- out_channels=channels[1],
- kernel_size=kernel_sizes[0],
- stride=1,
- dilation=dilations[0],
- padding=paddings[0],
- ),
- nn.Conv2d(
- in_channels=channels[1],
- out_channels=channels[1] // 2,
- kernel_size=kernel_sizes[1],
- stride=1,
- dilation=dilations[1],
- padding=paddings[1],
- ),
- )
- self.activation_fn = activation_fn
-
- self.conv = nn.Conv2d(
- in_channels=channels[0],
- out_channels=channels[1] // 2,
- kernel_size=1,
- dilation=1,
- stride=1,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- residual = self.conv(x)
- x = self.dilation_conv(x)
- x = torch.cat((x, residual), dim=1)
- return self.activation_fn(x)
-
-
-class FCN(nn.Module):
- def __init__(
- self,
- in_channels: int,
- base_channels: int,
- out_channels: int,
- kernel_size: int,
- dilations: Tuple[int] = (3, 7),
- paddings: Tuple[int] = (9, 21),
- num_blocks: int = 14,
- activation: str = "elu",
- ) -> None:
- super().__init__()
- self.kernel_sizes = [kernel_size] * num_blocks
- self.channels = [in_channels] + [base_channels] * (num_blocks - 1)
- self.out_channels = out_channels
- self.dilations = [dilations[0]] * (num_blocks // 2) + [dilations[1]] * (
- num_blocks // 2
- )
- self.paddings = [paddings[0]] * (num_blocks // 2) + [paddings[1]] * (
- num_blocks // 2
- )
- self.activation_fn = activation_function(activation)
- self.fcn = self._configure_fcn()
-
- def _configure_fcn(self) -> nn.Sequential:
- layers = []
- for i in range(0, len(self.channels), 2):
- layers.append(
- _DilatedBlock(
- self.channels[i : i + 2],
- self.kernel_sizes[i : i + 2],
- self.dilations[i : i + 2],
- self.paddings[i : i + 2],
- self.activation_fn,
- )
- )
- layers.append(
- nn.Conv2d(self.channels[-1], self.out_channels, kernel_size=1, stride=1)
- )
- return nn.Sequential(*layers)
-
- def forward(self, x: Tensor) -> Tensor:
- return self.fcn(x)
diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py
deleted file mode 100644
index 7f8c49b..0000000
--- a/src/text_recognizer/networks/neural_machine_reader.py
+++ /dev/null
@@ -1,201 +0,0 @@
-"""Sequence to sequence network with RNN cells."""
-# from typing import Dict, Optional, Tuple
-
-# from einops import rearrange
-# from einops.layers.torch import Rearrange
-# import torch
-# from torch import nn
-# from torch import Tensor
-
-# from text_recognizer.networks.util import configure_backbone
-
-
-# class Encoder(nn.Module):
-# def __init__(
-# self,
-# embedding_dim: int,
-# encoder_dim: int,
-# decoder_dim: int,
-# dropout_rate: float = 0.1,
-# ) -> None:
-# super().__init__()
-# self.rnn = nn.GRU(
-# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True
-# )
-# self.fc = nn.Sequential(
-# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh()
-# )
-# self.dropout = nn.Dropout(p=dropout_rate)
-
-# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-# """Encodes a sequence of tensors with a bidirectional GRU.
-
-# Args:
-# x (Tensor): A input sequence.
-
-# Shape:
-# - x: :math:`(T, N, E)`.
-# - output[0]: :math:`(T, N, 2 * E)`.
-# - output[1]: :math:`(T, N, D)`.
-
-# where T is the sequence length, N is the batch size, E is the
-# embedding/encoder dimension, and D is the decoder dimension.
-
-# Returns:
-# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the
-# encoder.
-
-# """
-
-# output, hidden = self.rnn(x)
-
-# # Get the hidden state from the forward and backward rnn.
-# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
-
-# # Apply fully connected layer and tanh activation.
-# hidden_state = self.fc(hidden_state)
-
-# return output, hidden_state
-
-
-# class Attention(nn.Module):
-# def __init__(self, encoder_dim: int, decoder_dim: int) -> None:
-# super().__init__()
-# self.atten = nn.Linear(
-# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim
-# )
-# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False)
-
-# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor:
-# """Short summary.
-
-# Args:
-# hidden_state (Tensor): Description of parameter `h`.
-# encoder_outputs (Tensor): Description of parameter `enc_out`.
-
-# Shape:
-# - x: :math:`(T, N, E)`.
-# - output[0]: :math:`(T, N, 2 * E)`.
-# - output[1]: :math:`(T, N, D)`.
-
-# where T is the sequence length, N is the batch size, E is the
-# embedding/encoder dimension, and D is the decoder dimension.
-
-# Returns:
-# Tensor: Description of returned object.
-
-# """
-# t, b = enc_out.shape[:2]
-# # repeat decoder hidden state src_len times
-# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1)
-
-# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2")
-
-# # Calculate the energy between the decoders previous hidden state and the
-# # encoders hidden states.
-# energy = torch.tanh(
-# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2))
-# )
-
-# attention = self.value(energy).squeeze(2)
-
-# # Apply softmax on the attention to squeeze it between 0 and 1.
-# attention = F.softmax(attention, dim=1)
-
-# return attention
-
-
-# class Decoder(nn.Module):
-# def __init__(
-# self,
-# embedding_dim: int,
-# encoder_dim: int,
-# decoder_dim: int,
-# output_dim: int,
-# dropout_rate: float = 0.1,
-# ) -> None:
-# super().__init__()
-# self.output_dim = output_dim
-# self.embedding = nn.Embedding(output_dim, embedding_dim)
-# self.attention = Attention(encoder_dim, decoder_dim)
-# self.rnn = nn.GRU(
-# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim
-# )
-
-# self.head = nn.Linear(
-# in_features=2 * encoder_dim + embedding_dim + decoder_dim,
-# out_features=output_dim,
-# )
-# self.dropout = nn.Dropout(p=dropout_rate)
-
-# def forward(
-# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor
-# ) -> Tensor:
-# # input = [batch size]
-# # hidden = [batch size, dec hid dim]
-# # encoder_outputs = [src len, batch size, enc hid dim * 2]
-# trg = trg.unsqueeze(0)
-# trg_embedded = self.dropout(self.embedding(trg))
-
-# a = self.attention(hidden_state, encoder_outputs)
-
-# weighted = torch.bmm(a, encoder_outputs)
-
-# # Permutate the tensor.
-# weighted = rearrange(weighted, "b a e2 -> a b e2")
-
-# rnn_input = torch.cat((trg_embedded, weighted), dim=2)
-
-# output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
-
-# # seq len, n layers and n directions will always be 1 in this decoder, therefore:
-# # output = [1, batch size, dec hid dim]
-# # hidden = [1, batch size, dec hid dim]
-# # this also means that output == hidden
-# assert (output == hidden).all()
-
-# trg_embedded = trg_embedded.squeeze(0)
-# output = output.squeeze(0)
-# weighted = weighted.squeeze(0)
-
-# logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim=1))
-
-# # prediction = [batch size, output dim]
-
-# return logits, hidden.squeeze(0)
-
-
-# class NeuralMachineReader(nn.Module):
-# def __init__(
-# self,
-# embedding_dim: int,
-# encoder_dim: int,
-# decoder_dim: int,
-# output_dim: int,
-# backbone: Optional[str] = None,
-# backbone_args: Optional[Dict] = None,
-# adaptive_pool_dim: Tuple = (None, 1),
-# dropout_rate: float = 0.1,
-# teacher_forcing_ratio: float = 0.5,
-# ) -> None:
-# super().__init__()
-
-# self.backbone = configure_backbone(backbone, backbone_args)
-# self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim))
-
-# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate)
-# self.decoder = Decoder(
-# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate
-# )
-# self.teacher_forcing_ratio = teacher_forcing_ratio
-
-# def extract_image_features(self, x: Tensor) -> Tensor:
-# x = self.backbone(x)
-# x = rearrange(x, "b c h w -> b w c h")
-# x = self.adaptive_pool(x)
-# x = x.squeeze(3)
-
-# def forward(self, x: Tensor, trg: Tensor) -> Tensor:
-# # x = [batch size, height, width]
-# # trg = [trg len, batch size]
-# z = self.extract_image_features(x)
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 6405192..e397224 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,7 +7,6 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.stn import SpatialTransformerNetwork
from text_recognizer.networks.util import activation_function
@@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module):
activation: str = "relu",
block: Type[nn.Module] = BasicBlock,
levels: int = 1,
- stn: bool = False,
*args,
**kwargs
) -> None:
super().__init__()
- self.stn = SpatialTransformerNetwork() if stn else None
self.block_sizes = (
block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
)
@@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module):
),
nn.BatchNorm2d(self.block_sizes[0]),
activation_function(self.activation),
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+ # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
)
self.blocks = self._configure_blocks(block)
@@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module):
# If batch dimenstion is missing, it needs to be added.
if len(x.shape) == 3:
x = x.unsqueeze(0)
- if self.stn is not None:
- x = self.stn(x)
x = self.gate(x)
x = self.blocks(x)
return x
diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py
index 51f242a..510910f 100644
--- a/src/text_recognizer/networks/unet.py
+++ b/src/text_recognizer/networks/unet.py
@@ -8,64 +8,118 @@ from torch import Tensor
from text_recognizer.networks.util import activation_function
-class ConvBlock(nn.Module):
- """Basic UNet convolutional block."""
+class _ConvBlock(nn.Module):
+ """Modified UNet convolutional block with dilation."""
- def __init__(self, channels: List[int], activation: str) -> None:
+ def __init__(
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
+ ) -> None:
super().__init__()
self.channels = channels
+ self.dropout_rate = dropout_rate
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.padding = padding
+ self.num_groups = num_groups
self.activation = activation_function(activation)
self.block = self._configure_block()
+ self.residual_conv = nn.Sequential(
+ nn.Conv2d(
+ self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
+ ),
+ self.activation,
+ )
def _configure_block(self) -> nn.Sequential:
block = []
for i in range(len(self.channels) - 1):
block += [
+ nn.Dropout(p=self.dropout_rate),
+ nn.GroupNorm(self.num_groups, self.channels[i]),
+ self.activation,
nn.Conv2d(
- self.channels[i], self.channels[i + 1], kernel_size=3, padding=1
+ self.channels[i],
+ self.channels[i + 1],
+ kernel_size=self.kernel_size,
+ padding=self.padding,
+ stride=1,
+ dilation=self.dilation,
),
- nn.BatchNorm2d(self.channels[i + 1]),
- self.activation,
]
return nn.Sequential(*block)
def forward(self, x: Tensor) -> Tensor:
"""Apply the convolutional block."""
- return self.block(x)
+ residual = self.residual_conv(x)
+ return self.block(x) + residual
-class DownSamplingBlock(nn.Module):
+class _DownSamplingBlock(nn.Module):
"""Basic down sampling block."""
def __init__(
self,
channels: List[int],
activation: str,
+ num_groups: int,
pooling_kernel: Union[int, bool] = 2,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
) -> None:
super().__init__()
- self.conv_block = ConvBlock(channels, activation)
+ self.conv_block = _ConvBlock(
+ channels,
+ activation,
+ num_groups,
+ dropout_rate,
+ kernel_size,
+ dilation,
+ padding,
+ )
self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Return the convolutional block output and a down sampled tensor."""
x = self.conv_block(x)
- if self.down_sampling is not None:
- x_down = self.down_sampling(x)
- else:
- x_down = None
+ x_down = self.down_sampling(x) if self.down_sampling is not None else x
+
return x_down, x
-class UpSamplingBlock(nn.Module):
+class _UpSamplingBlock(nn.Module):
"""The upsampling block of the UNet."""
def __init__(
- self, channels: List[int], activation: str, scale_factor: int = 2
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ scale_factor: int = 2,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
) -> None:
super().__init__()
- self.conv_block = ConvBlock(channels, activation)
+ self.conv_block = _ConvBlock(
+ channels,
+ activation,
+ num_groups,
+ dropout_rate,
+ kernel_size,
+ dilation,
+ padding,
+ )
self.up_sampling = nn.Upsample(
scale_factor=scale_factor, mode="bilinear", align_corners=True
)
@@ -87,14 +141,43 @@ class UNet(nn.Module):
base_channels: int = 64,
num_classes: int = 3,
depth: int = 4,
- out_channels: int = 3,
activation: str = "relu",
+ num_groups: int = 8,
+ dropout_rate: float = 0.1,
pooling_kernel: int = 2,
scale_factor: int = 2,
+ kernel_size: Optional[List[int]] = None,
+ dilation: Optional[List[int]] = None,
+ padding: Optional[List[int]] = None,
) -> None:
super().__init__()
self.depth = depth
- channels = [1] + [base_channels * 2 ** i for i in range(depth)]
+ self.num_groups = num_groups
+
+ if kernel_size is not None and dilation is not None and padding is not None:
+ if (
+ len(kernel_size) != depth
+ and len(dilation) != depth
+ and len(padding) != depth
+ ):
+ raise RuntimeError(
+ "Length of convolutional parameters does not match the depth."
+ )
+ self.kernel_size = kernel_size
+ self.padding = padding
+ self.dilation = dilation
+
+ else:
+ self.kernel_size = [3] * depth
+ self.padding = [1] * depth
+ self.dilation = [1] * depth
+
+ self.dropout_rate = dropout_rate
+ self.conv = nn.Conv2d(
+ in_channels, base_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
self.encoder_blocks = self._configure_down_sampling_blocks(
channels, activation, pooling_kernel
)
@@ -110,49 +193,63 @@ class UNet(nn.Module):
blocks = nn.ModuleList([])
for i in range(len(channels) - 1):
pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+ dropout_rate = self.dropout_rate if i < 0 else 0
blocks += [
- DownSamplingBlock(
+ _DownSamplingBlock(
[channels[i], channels[i + 1], channels[i + 1]],
activation,
+ self.num_groups,
pooling_kernel,
+ dropout_rate,
+ self.kernel_size[i],
+ self.dilation[i],
+ self.padding[i],
)
]
return blocks
def _configure_up_sampling_blocks(
- self,
- channels: List[int],
- activation: str,
- scale_factor: int,
+ self, channels: List[int], activation: str, scale_factor: int,
) -> nn.ModuleList:
channels.reverse()
+ self.kernel_size.reverse()
+ self.dilation.reverse()
+ self.padding.reverse()
return nn.ModuleList(
[
- UpSamplingBlock(
+ _UpSamplingBlock(
[channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
activation,
+ self.num_groups,
scale_factor,
+ self.dropout_rate,
+ self.kernel_size[i],
+ self.dilation[i],
+ self.padding[i],
)
for i in range(len(channels) - 2)
]
)
- def encode(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
+ def _encode(self, x: Tensor) -> List[Tensor]:
x_skips = []
for block in self.encoder_blocks:
x, x_skip = block(x)
- if x_skip is not None:
- x_skips.append(x_skip)
- return x, x_skips
+ x_skips.append(x_skip)
+ return x_skips
- def decode(self, x: Tensor, x_skips: List[Tensor]) -> Tensor:
+ def _decode(self, x_skips: List[Tensor]) -> Tensor:
x = x_skips[-1]
for i, block in enumerate(self.decoder_blocks):
x = block(x, x_skips[-(i + 2)])
return x
def forward(self, x: Tensor) -> Tensor:
- x, x_skips = self.encode(x)
- x = self.decode(x, x_skips)
+ """Forward pass with the UNet model."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ x = self.conv(x)
+ x_skips = self._encode(x)
+ x = self._decode(x_skips)
return self.head(x)