diff options
-rw-r--r-- | poetry.lock | 52 | ||||
-rw-r--r-- | pyproject.toml | 3 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 364 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py | 520 | ||||
-rw-r--r-- | training/.gitignore | 1 | ||||
-rw-r--r-- | training/conf/callbacks/default.yaml | 14 | ||||
-rw-r--r-- | training/conf/callbacks/swa.yaml | 16 | ||||
-rw-r--r-- | training/conf/cnn_transformer.yaml (renamed from training/configs/cnn_transformer.yaml) | 0 | ||||
-rw-r--r-- | training/conf/config.yaml | 6 | ||||
-rw-r--r-- | training/conf/dataset/iam_extended_paragraphs.yaml | 7 | ||||
-rw-r--r-- | training/conf/model/lit_vqvae.yaml | 24 | ||||
-rw-r--r-- | training/conf/network/vqvae.yaml | 14 | ||||
-rw-r--r-- | training/conf/trainer/default.yaml | 18 | ||||
-rw-r--r-- | training/configs/vqvae.yaml | 89 | ||||
-rw-r--r-- | training/run_experiment.py | 136 |
18 files changed, 635 insertions, 637 deletions
diff --git a/poetry.lock b/poetry.lock index a8034e7..81ccaed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -37,6 +37,14 @@ optional = false python-versions = "*" [[package]] +name = "antlr4-python3-runtime" +version = "4.8" +description = "ANTLR 4.8 runtime for Python 3.7" +category = "main" +optional = false +python-versions = "*" + +[[package]] name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." @@ -622,6 +630,19 @@ numpy = [ ] [[package]] +name = "hydra-core" +version = "1.0.6" +description = "A framework for elegantly configuring complex applications" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +antlr4-python3-runtime = "4.8" +importlib-resources = {version = "*", markers = "python_version < \"3.9\""} +omegaconf = ">=2.0.5,<2.1" + +[[package]] name = "idna" version = "2.10" description = "Internationalized Domain Names in Applications (IDNA)" @@ -638,6 +659,18 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] +name = "importlib-resources" +version = "5.1.2" +description = "Read resources from Python packages" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "pytest-enabler", "pytest-black (>=0.3.7)", "pytest-mypy"] + +[[package]] name = "ipykernel" version = "5.5.3" description = "IPython Kernel for Jupyter" @@ -2086,7 +2119,7 @@ brotli = ["brotlipy (>=0.6.0)"] [[package]] name = "wandb" -version = "0.10.27" +version = "0.10.28" description = "A CLI and library for interacting with the Weights and Biases API." category = "main" optional = false @@ -2197,7 +2230,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "07d5f14d7c55a961ce1841ecd125c0f7c83d5649cc118fcae2ed3b58347ca8c2" +content-hash = "bcc456879df9e9ec937f1aa8c339fe96ca18ce95b12b12fc10c23494829296e8" [metadata.files] absl-py = [ @@ -2247,6 +2280,9 @@ alabaster = [ {file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"}, {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"}, ] +antlr4-python3-runtime = [ + {file = "antlr4-python3-runtime-4.8.tar.gz", hash = "sha256:15793f5d0512a372b4e7d2284058ad32ce7dd27126b105fb0b2245130445db33"}, +] appdirs = [ {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, @@ -2628,6 +2664,10 @@ h5py = [ {file = "h5py-3.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c5b5f18c96fb63399280a724734fd91e1781c6b60e385e439ad8e654a294ba4"}, {file = "h5py-3.2.1.tar.gz", hash = "sha256:89474be911bfcdb34cbf0d98b8ec48b578c27a89fdb1ae4ee7513f1ef8d9249e"}, ] +hydra-core = [ + {file = "hydra-core-1.0.6.tar.gz", hash = "sha256:be5fe119eb41ada20c82835b153b956781de7e2506670ba942f7453b2e850950"}, + {file = "hydra_core-1.0.6-py3-none-any.whl", hash = "sha256:500d4346b7afcd654276c87c15820d7e6b76c2da95ad698cceb4120d7a877b32"}, +] idna = [ {file = "idna-2.10-py2.py3-none-any.whl", hash = "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"}, {file = "idna-2.10.tar.gz", hash = "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6"}, @@ -2636,6 +2676,10 @@ imagesize = [ {file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"}, {file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"}, ] +importlib-resources = [ + {file = "importlib_resources-5.1.2-py3-none-any.whl", hash = "sha256:ebab3efe74d83b04d6bf5cd9a17f0c5c93e60fb60f30c90f56265fce4682a469"}, + {file = "importlib_resources-5.1.2.tar.gz", hash = "sha256:642586fc4740bd1cad7690f836b3321309402b20b332529f25617ff18e8e1370"}, +] ipykernel = [ {file = "ipykernel-5.5.3-py3-none-any.whl", hash = "sha256:21abd584543759e49010975a4621603b3cf871b1039cb3879a14094717692614"}, {file = "ipykernel-5.5.3.tar.gz", hash = "sha256:a682e4f7affd86d9ce9b699d21bcab6d5ec9fbb2bfcb194f2706973b252bc509"}, @@ -3635,8 +3679,8 @@ urllib3 = [ {file = "urllib3-1.26.4.tar.gz", hash = "sha256:e7b021f7241115872f92f43c6508082facffbd1c048e3c6e2bb9c2a157e28937"}, ] wandb = [ - {file = "wandb-0.10.27-py2.py3-none-any.whl", hash = "sha256:f591bb9d5c402ec5c12bd823db913d49ac31e4068f2c35c50b6f13e4fcd717b4"}, - {file = "wandb-0.10.27.tar.gz", hash = "sha256:1b6d3bbfd644183bbd79a02ff9e57e65a8a4f7c9b770af0d3c4f961ae6d7fc73"}, + {file = "wandb-0.10.28-py2.py3-none-any.whl", hash = "sha256:609f2605ec82f846490a2bdd9d01fa947b1892a76133cbb80c41501e5dfda763"}, + {file = "wandb-0.10.28.tar.gz", hash = "sha256:b48aa55f147717e197d38b0dd9d9ef3662efe54fefc0be2909f93e8a5c0cc71b"}, ] wcwidth = [ {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, diff --git a/pyproject.toml b/pyproject.toml index f9bce95..6f050c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ nltk = "^3.5" torch-summary = "^1.4.2" defusedxml = "^0.6.0" omegaconf = "^2.0.2" -wandb = "^0.10.27" +wandb = "^0.10.28" einops = "^0.3.0" gtn = "^0.0.0" sentencepiece = "^0.1.95" @@ -41,6 +41,7 @@ Pillow = "^8.1.2" madgrad = "^1.0" editdistance = "^0.5.3" torchmetrics = "^0.2.0" +hydra-core = "^1.0.6" [tool.poetry.dev-dependencies] pytest = "^5.4.2" diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 78e6c05..ad6fa25 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -76,7 +76,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): def setup(self, stage: str = None) -> None: """Loading synthetic dataset.""" - logger.info(f"IAM Synthetic dataset steup for stage {stage}") + logger.info(f"IAM Synthetic dataset steup for stage {stage}...") if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels( diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index a9117f8..d1ebf1a 100644 --- a/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py @@ -1,4 +1,5 @@ """Network modules""" from .encoders import EfficientNet from .vqvae import VQVAE -from .cnn_transformer import CNNTransformer + +# from .cnn_transformer import CNNTransformer diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index d42c29d..80798e1 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,182 +1,182 @@ -"""A Transformer with a cnn backbone. - -The network encodes a image with a convolutional backbone to a latent representation, -i.e. feature maps. A 2d positional encoding is applied to the feature maps for -spatial information. The resulting feature are then set to a transformer decoder -together with the target tokens. - -TODO: Local attention for lower layer in attention. - -""" -import importlib -import math -from typing import Dict, Optional, Union, Sequence, Type - -from einops import rearrange -from omegaconf import DictConfig, OmegaConf -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -from text_recognizer.networks.transformer import ( - Decoder, - DecoderLayer, - PositionalEncoding, - PositionalEncoding2D, - target_padding_mask, -) - -NUM_WORD_PIECES = 1000 - - -class CNNTransformer(nn.Module): - def __init__( - self, - input_dim: Sequence[int], - output_dims: Sequence[int], - encoder: Union[DictConfig, Dict], - vocab_size: Optional[int] = None, - num_decoder_layers: int = 4, - hidden_dim: int = 256, - num_heads: int = 4, - expansion_dim: int = 1024, - dropout_rate: float = 0.1, - transformer_activation: str = "glu", - *args, - **kwargs, - ) -> None: - super().__init__() - self.vocab_size = ( - NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size - ) - self.pad_index = 3 # TODO: fix me - self.hidden_dim = hidden_dim - self.max_output_length = output_dims[0] - - # Image backbone - self.encoder = self._configure_encoder(encoder) - self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) - self.feature_map_encoding = PositionalEncoding2D( - hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] - ) - - # Target token embedding - self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.trg_position_encoding = PositionalEncoding( - hidden_dim, dropout_rate, max_len=output_dims[0] - ) - - # Transformer decoder - self.decoder = Decoder( - decoder_layer=DecoderLayer( - hidden_dim=hidden_dim, - num_heads=num_heads, - expansion_dim=expansion_dim, - dropout_rate=dropout_rate, - activation=transformer_activation, - ), - num_layers=num_decoder_layers, - norm=nn.LayerNorm(hidden_dim), - ) - - # Classification head - self.head = nn.Linear(hidden_dim, self.vocab_size) - - # Initialize weights - self._init_weights() - - def _init_weights(self) -> None: - """Initialize network weights.""" - self.trg_embedding.weight.data.uniform_(-0.1, 0.1) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-0.1, 0.1) - - nn.init.kaiming_normal_( - self.encoder_proj.weight.data, - a=0, - mode="fan_out", - nonlinearity="relu", - ) - if self.encoder_proj.bias is not None: - _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.encoder_proj.weight.data - ) - bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.encoder_proj.bias, -bound, bound) - - @staticmethod - def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: - encoder = OmegaConf.create(encoder) - args = encoder.args or {} - network_module = importlib.import_module("text_recognizer.networks") - encoder_class = getattr(network_module, encoder.type) - return encoder_class(**args) - - def encode(self, image: Tensor) -> Tensor: - """Extracts image features with backbone. - - Args: - image (Tensor): Image(s) of handwritten text. - - Retuns: - Tensor: Image features. - - Shapes: - - image: :math: `(B, C, H, W)` - - latent: :math: `(B, T, C)` - - """ - # Extract image features. - image_features = self.encoder(image) - image_features = self.encoder_proj(image_features) - - # Add 2d encoding to the feature maps. - image_features = self.feature_map_encoding(image_features) - - # Collapse features maps height and width. - image_features = rearrange(image_features, "b c h w -> b (h w) c") - return image_features - - def decode(self, memory: Tensor, trg: Tensor) -> Tensor: - """Decodes image features with transformer decoder.""" - trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) - trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) - trg = rearrange(trg, "b t d -> t b d") - trg = self.trg_position_encoding(trg) - trg = rearrange(trg, "t b d -> b t d") - out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) - logits = self.head(out) - return logits - - def forward(self, image: Tensor, trg: Tensor) -> Tensor: - image_features = self.encode(image) - output = self.decode(image_features, trg) - output = rearrange(output, "b t c -> b c t") - return output - - def predict(self, image: Tensor) -> Tensor: - """Transcribes text in image(s).""" - bsz = image.shape[0] - image_features = self.encode(image) - - output_tokens = ( - (torch.ones((bsz, self.max_output_length)) * self.pad_index) - .type_as(image) - .long() - ) - output_tokens[:, 0] = self.start_index - for i in range(1, self.max_output_length): - trg = output_tokens[:, :i] - output = self.decode(image_features, trg) - output = torch.argmax(output, dim=-1) - output_tokens[:, i] = output[-1:] - - # Set all tokens after end token to be padding. - for i in range(1, self.max_output_length): - indices = output_tokens[:, i - 1] == self.end_index | ( - output_tokens[:, i - 1] == self.pad_index - ) - output_tokens[indices, i] = self.pad_index - - return output_tokens +# """A Transformer with a cnn backbone. +# +# The network encodes a image with a convolutional backbone to a latent representation, +# i.e. feature maps. A 2d positional encoding is applied to the feature maps for +# spatial information. The resulting feature are then set to a transformer decoder +# together with the target tokens. +# +# TODO: Local attention for lower layer in attention. +# +# """ +# import importlib +# import math +# from typing import Dict, Optional, Union, Sequence, Type +# +# from einops import rearrange +# from omegaconf import DictConfig, OmegaConf +# import torch +# from torch import nn +# from torch import Tensor +# +# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS +# from text_recognizer.networks.transformer import ( +# Decoder, +# DecoderLayer, +# PositionalEncoding, +# PositionalEncoding2D, +# target_padding_mask, +# ) +# +# NUM_WORD_PIECES = 1000 +# +# +# class CNNTransformer(nn.Module): +# def __init__( +# self, +# input_dim: Sequence[int], +# output_dims: Sequence[int], +# encoder: Union[DictConfig, Dict], +# vocab_size: Optional[int] = None, +# num_decoder_layers: int = 4, +# hidden_dim: int = 256, +# num_heads: int = 4, +# expansion_dim: int = 1024, +# dropout_rate: float = 0.1, +# transformer_activation: str = "glu", +# *args, +# **kwargs, +# ) -> None: +# super().__init__() +# self.vocab_size = ( +# NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size +# ) +# self.pad_index = 3 # TODO: fix me +# self.hidden_dim = hidden_dim +# self.max_output_length = output_dims[0] +# +# # Image backbone +# self.encoder = self._configure_encoder(encoder) +# self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) +# self.feature_map_encoding = PositionalEncoding2D( +# hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] +# ) +# +# # Target token embedding +# self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) +# self.trg_position_encoding = PositionalEncoding( +# hidden_dim, dropout_rate, max_len=output_dims[0] +# ) +# +# # Transformer decoder +# self.decoder = Decoder( +# decoder_layer=DecoderLayer( +# hidden_dim=hidden_dim, +# num_heads=num_heads, +# expansion_dim=expansion_dim, +# dropout_rate=dropout_rate, +# activation=transformer_activation, +# ), +# num_layers=num_decoder_layers, +# norm=nn.LayerNorm(hidden_dim), +# ) +# +# # Classification head +# self.head = nn.Linear(hidden_dim, self.vocab_size) +# +# # Initialize weights +# self._init_weights() +# +# def _init_weights(self) -> None: +# """Initialize network weights.""" +# self.trg_embedding.weight.data.uniform_(-0.1, 0.1) +# self.head.bias.data.zero_() +# self.head.weight.data.uniform_(-0.1, 0.1) +# +# nn.init.kaiming_normal_( +# self.encoder_proj.weight.data, +# a=0, +# mode="fan_out", +# nonlinearity="relu", +# ) +# if self.encoder_proj.bias is not None: +# _, fan_out = nn.init._calculate_fan_in_and_fan_out( +# self.encoder_proj.weight.data +# ) +# bound = 1 / math.sqrt(fan_out) +# nn.init.normal_(self.encoder_proj.bias, -bound, bound) +# +# @staticmethod +# def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: +# encoder = OmegaConf.create(encoder) +# args = encoder.args or {} +# network_module = importlib.import_module("text_recognizer.networks") +# encoder_class = getattr(network_module, encoder.type) +# return encoder_class(**args) +# +# def encode(self, image: Tensor) -> Tensor: +# """Extracts image features with backbone. +# +# Args: +# image (Tensor): Image(s) of handwritten text. +# +# Retuns: +# Tensor: Image features. +# +# Shapes: +# - image: :math: `(B, C, H, W)` +# - latent: :math: `(B, T, C)` +# +# """ +# # Extract image features. +# image_features = self.encoder(image) +# image_features = self.encoder_proj(image_features) +# +# # Add 2d encoding to the feature maps. +# image_features = self.feature_map_encoding(image_features) +# +# # Collapse features maps height and width. +# image_features = rearrange(image_features, "b c h w -> b (h w) c") +# return image_features +# +# def decode(self, memory: Tensor, trg: Tensor) -> Tensor: +# """Decodes image features with transformer decoder.""" +# trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) +# trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) +# trg = rearrange(trg, "b t d -> t b d") +# trg = self.trg_position_encoding(trg) +# trg = rearrange(trg, "t b d -> b t d") +# out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) +# logits = self.head(out) +# return logits +# +# def forward(self, image: Tensor, trg: Tensor) -> Tensor: +# image_features = self.encode(image) +# output = self.decode(image_features, trg) +# output = rearrange(output, "b t c -> b c t") +# return output +# +# def predict(self, image: Tensor) -> Tensor: +# """Transcribes text in image(s).""" +# bsz = image.shape[0] +# image_features = self.encode(image) +# +# output_tokens = ( +# (torch.ones((bsz, self.max_output_length)) * self.pad_index) +# .type_as(image) +# .long() +# ) +# output_tokens[:, 0] = self.start_index +# for i in range(1, self.max_output_length): +# trg = output_tokens[:, :i] +# output = self.decode(image_features, trg) +# output = torch.argmax(output, dim=-1) +# output_tokens[:, i] = output[-1:] +# +# # Set all tokens after end token to be padding. +# for i in range(1, self.max_output_length): +# indices = output_tokens[:, i - 1] == self.end_index | ( +# output_tokens[:, i - 1] == self.pad_index +# ) +# output_tokens[indices, i] = self.pad_index +# +# return output_tokens diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 627fa7b..4ff48f7 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -4,4 +4,5 @@ from .positional_encoding import ( PositionalEncoding2D, target_padding_mask, ) -from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer + +# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index 5ac2787..d49c85a 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -1,260 +1,260 @@ -"""Transfomer module.""" -import copy -from typing import Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.transformer.attention import MultiHeadAttention -from text_recognizer.networks.util import activation_function - - -class GEGLU(nn.Module): - """GLU activation for improving feedforward activations.""" - - def __init__(self, dim_in: int, dim_out: int) -> None: - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x: Tensor) -> Tensor: - """Forward propagation.""" - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: - return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) - - -class _IntraLayerConnection(nn.Module): - """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" - - def __init__(self, dropout_rate: float, hidden_dim: int) -> None: - super().__init__() - self.norm = nn.LayerNorm(normalized_shape=hidden_dim) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward(self, src: Tensor, residual: Tensor) -> Tensor: - return self.norm(self.dropout(src) + residual) - - -class FeedForward(nn.Module): - def __init__( - self, - hidden_dim: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - in_projection = ( - nn.Sequential( - nn.Linear(hidden_dim, expansion_dim), activation_function(activation) - ) - if activation != "glu" - else GEGLU(hidden_dim, expansion_dim) - ) - - self.layer = nn.Sequential( - in_projection, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=expansion_dim, out_features=hidden_dim), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layer(x) - - -class EncoderLayer(nn.Module): - """Transfomer encoding layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through the encoder.""" - # First block. - # Multi head attention. - out, _ = self.self_attention(src, src, src, mask) - - # Add & norm. - out = self.block1(out, src) - - # Second block. - # Apply 1D-convolution. - mlp_out = self.mlp(out) - - # Add & norm. - out = self.block2(mlp_out, out) - - return out - - -class Encoder(nn.Module): - """Transfomer encoder module.""" - - def __init__( - self, - num_layers: int, - encoder_layer: Type[nn.Module], - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.norm = norm - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through all encoder layers.""" - for layer in self.layers: - src = layer(src, src_mask) - - if self.norm is not None: - src = self.norm(src) - - return src - - -class DecoderLayer(nn.Module): - """Transfomer decoder layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float = 0.0, - activation: str = "relu", - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.multihead_attention = MultiHeadAttention( - hidden_dim, num_heads, dropout_rate - ) - self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass of the layer.""" - out, _ = self.self_attention(trg, trg, trg, trg_mask) - trg = self.block1(out, trg) - - out, _ = self.multihead_attention(trg, memory, memory, memory_mask) - trg = self.block2(out, trg) - - out = self.mlp(trg) - out = self.block3(out, trg) - - return out - - -class Decoder(nn.Module): - """Transfomer decoder module.""" - - def __init__( - self, - decoder_layer: Type[nn.Module], - num_layers: int, - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the decoder.""" - for layer in self.layers: - trg = layer(trg, memory, trg_mask, memory_mask) - - if self.norm is not None: - trg = self.norm(trg) - - return trg - - -class Transformer(nn.Module): - """Transformer network.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - # Configure encoder. - encoder_norm = nn.LayerNorm(hidden_dim) - encoder_layer = EncoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) - - # Configure decoder. - decoder_norm = nn.LayerNorm(hidden_dim) - decoder_layer = DecoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward( - self, - src: Tensor, - trg: Tensor, - src_mask: Optional[Tensor] = None, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the transformer.""" - if src.shape[0] != trg.shape[0]: - print(trg.shape) - raise RuntimeError("The batch size of the src and trg must be the same.") - if src.shape[2] != trg.shape[2]: - raise RuntimeError( - "The number of features for the src and trg must be the same." - ) - - memory = self.encoder(src, src_mask) - output = self.decoder(trg, memory, trg_mask, memory_mask) - return output +# """Transfomer module.""" +# import copy +# from typing import Dict, Optional, Type, Union +# +# import numpy as np +# import torch +# from torch import nn +# from torch import Tensor +# import torch.nn.functional as F +# +# from text_recognizer.networks.transformer.attention import MultiHeadAttention +# from text_recognizer.networks.util import activation_function +# +# +# class GEGLU(nn.Module): +# """GLU activation for improving feedforward activations.""" +# +# def __init__(self, dim_in: int, dim_out: int) -> None: +# super().__init__() +# self.proj = nn.Linear(dim_in, dim_out * 2) +# +# def forward(self, x: Tensor) -> Tensor: +# """Forward propagation.""" +# x, gate = self.proj(x).chunk(2, dim=-1) +# return x * F.gelu(gate) +# +# +# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: +# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) +# +# +# class _IntraLayerConnection(nn.Module): +# """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" +# +# def __init__(self, dropout_rate: float, hidden_dim: int) -> None: +# super().__init__() +# self.norm = nn.LayerNorm(normalized_shape=hidden_dim) +# self.dropout = nn.Dropout(p=dropout_rate) +# +# def forward(self, src: Tensor, residual: Tensor) -> Tensor: +# return self.norm(self.dropout(src) + residual) +# +# +# class FeedForward(nn.Module): +# def __init__( +# self, +# hidden_dim: int, +# expansion_dim: int, +# dropout_rate: float, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# +# in_projection = ( +# nn.Sequential( +# nn.Linear(hidden_dim, expansion_dim), activation_function(activation) +# ) +# if activation != "glu" +# else GEGLU(hidden_dim, expansion_dim) +# ) +# +# self.layer = nn.Sequential( +# in_projection, +# nn.Dropout(p=dropout_rate), +# nn.Linear(in_features=expansion_dim, out_features=hidden_dim), +# ) +# +# def forward(self, x: Tensor) -> Tensor: +# return self.layer(x) +# +# +# class EncoderLayer(nn.Module): +# """Transfomer encoding layer.""" +# +# def __init__( +# self, +# hidden_dim: int, +# num_heads: int, +# expansion_dim: int, +# dropout_rate: float, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) +# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) +# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) +# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) +# +# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: +# """Forward pass through the encoder.""" +# # First block. +# # Multi head attention. +# out, _ = self.self_attention(src, src, src, mask) +# +# # Add & norm. +# out = self.block1(out, src) +# +# # Second block. +# # Apply 1D-convolution. +# mlp_out = self.mlp(out) +# +# # Add & norm. +# out = self.block2(mlp_out, out) +# +# return out +# +# +# class Encoder(nn.Module): +# """Transfomer encoder module.""" +# +# def __init__( +# self, +# num_layers: int, +# encoder_layer: Type[nn.Module], +# norm: Optional[Type[nn.Module]] = None, +# ) -> None: +# super().__init__() +# self.layers = _get_clones(encoder_layer, num_layers) +# self.norm = norm +# +# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: +# """Forward pass through all encoder layers.""" +# for layer in self.layers: +# src = layer(src, src_mask) +# +# if self.norm is not None: +# src = self.norm(src) +# +# return src +# +# +# class DecoderLayer(nn.Module): +# """Transfomer decoder layer.""" +# +# def __init__( +# self, +# hidden_dim: int, +# num_heads: int, +# expansion_dim: int, +# dropout_rate: float = 0.0, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# self.hidden_dim = hidden_dim +# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) +# self.multihead_attention = MultiHeadAttention( +# hidden_dim, num_heads, dropout_rate +# ) +# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation) +# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) +# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) +# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) +# +# def forward( +# self, +# trg: Tensor, +# memory: Tensor, +# trg_mask: Optional[Tensor] = None, +# memory_mask: Optional[Tensor] = None, +# ) -> Tensor: +# """Forward pass of the layer.""" +# out, _ = self.self_attention(trg, trg, trg, trg_mask) +# trg = self.block1(out, trg) +# +# out, _ = self.multihead_attention(trg, memory, memory, memory_mask) +# trg = self.block2(out, trg) +# +# out = self.mlp(trg) +# out = self.block3(out, trg) +# +# return out +# +# +# class Decoder(nn.Module): +# """Transfomer decoder module.""" +# +# def __init__( +# self, +# decoder_layer: Type[nn.Module], +# num_layers: int, +# norm: Optional[Type[nn.Module]] = None, +# ) -> None: +# super().__init__() +# self.layers = _get_clones(decoder_layer, num_layers) +# self.num_layers = num_layers +# self.norm = norm +# +# def forward( +# self, +# trg: Tensor, +# memory: Tensor, +# trg_mask: Optional[Tensor] = None, +# memory_mask: Optional[Tensor] = None, +# ) -> Tensor: +# """Forward pass through the decoder.""" +# for layer in self.layers: +# trg = layer(trg, memory, trg_mask, memory_mask) +# +# if self.norm is not None: +# trg = self.norm(trg) +# +# return trg +# +# +# class Transformer(nn.Module): +# """Transformer network.""" +# +# def __init__( +# self, +# num_encoder_layers: int, +# num_decoder_layers: int, +# hidden_dim: int, +# num_heads: int, +# expansion_dim: int, +# dropout_rate: float, +# activation: str = "relu", +# ) -> None: +# super().__init__() +# +# # Configure encoder. +# encoder_norm = nn.LayerNorm(hidden_dim) +# encoder_layer = EncoderLayer( +# hidden_dim, num_heads, expansion_dim, dropout_rate, activation +# ) +# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) +# +# # Configure decoder. +# decoder_norm = nn.LayerNorm(hidden_dim) +# decoder_layer = DecoderLayer( +# hidden_dim, num_heads, expansion_dim, dropout_rate, activation +# ) +# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) +# +# self._reset_parameters() +# +# def _reset_parameters(self) -> None: +# for p in self.parameters(): +# if p.dim() > 1: +# nn.init.xavier_uniform_(p) +# +# def forward( +# self, +# src: Tensor, +# trg: Tensor, +# src_mask: Optional[Tensor] = None, +# trg_mask: Optional[Tensor] = None, +# memory_mask: Optional[Tensor] = None, +# ) -> Tensor: +# """Forward pass through the transformer.""" +# if src.shape[0] != trg.shape[0]: +# print(trg.shape) +# raise RuntimeError("The batch size of the src and trg must be the same.") +# if src.shape[2] != trg.shape[2]: +# raise RuntimeError( +# "The number of features for the src and trg must be the same." +# ) +# +# memory = self.encoder(src, src_mask) +# output = self.decoder(trg, memory, trg_mask, memory_mask) +# return output diff --git a/training/.gitignore b/training/.gitignore index 333c1e9..7d268ea 100644 --- a/training/.gitignore +++ b/training/.gitignore @@ -1 +1,2 @@ logs/ +outputs/ diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml new file mode 100644 index 0000000..74dc30c --- /dev/null +++ b/training/conf/callbacks/default.yaml @@ -0,0 +1,14 @@ +# @package _group_ +- type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true +- type: LearningRateMonitor + args: + logging_interval: step +# - type: EarlyStopping +# args: +# monitor: val_loss +# mode: min +# patience: 10 diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml new file mode 100644 index 0000000..144ad6e --- /dev/null +++ b/training/conf/callbacks/swa.yaml @@ -0,0 +1,16 @@ +# @package _group_ +- type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true +- type: StochasticWeightAveraging + args: + swa_epoch_start: 0.8 + swa_lrs: 0.05 + annealing_epochs: 10 + annealing_strategy: cos + device: null +- type: LearningRateMonitor + args: + logging_interval: step diff --git a/training/configs/cnn_transformer.yaml b/training/conf/cnn_transformer.yaml index a4f16df..a4f16df 100644 --- a/training/configs/cnn_transformer.yaml +++ b/training/conf/cnn_transformer.yaml diff --git a/training/conf/config.yaml b/training/conf/config.yaml new file mode 100644 index 0000000..11adeb7 --- /dev/null +++ b/training/conf/config.yaml @@ -0,0 +1,6 @@ +defaults: + - network: vqvae + - model: lit_vqvae + - dataset: iam_extended_paragraphs + - trainer: default + - callbacks: default diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml new file mode 100644 index 0000000..6bd7fc9 --- /dev/null +++ b/training/conf/dataset/iam_extended_paragraphs.yaml @@ -0,0 +1,7 @@ +# @package _group_ +type: IAMExtendedParagraphs +args: + batch_size: 32 + num_workers: 12 + train_fraction: 0.8 + augment: true diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml new file mode 100644 index 0000000..90780b7 --- /dev/null +++ b/training/conf/model/lit_vqvae.yaml @@ -0,0 +1,24 @@ +# @package _group_ +type: LitVQVAEModel +args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: OneCycleLR + args: + interval: step + max_lr: 1.0e-3 + three_phase: true + epochs: 64 + steps_per_epoch: 633 # num_samples / batch_size + criterion: + type: MSELoss + args: + reduction: mean + monitor: val_loss + mapping: sentence_piece diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml new file mode 100644 index 0000000..8c30bbd --- /dev/null +++ b/training/conf/network/vqvae.yaml @@ -0,0 +1,14 @@ +# @package _group_ +type: VQVAE +args: + in_channels: 1 + channels: [32, 64, 64] + kernel_sizes: [4, 4, 4] + strides: [2, 2, 2] + num_residual_layers: 2 + embedding_dim: 64 + num_embeddings: 256 + upsampling: null + beta: 0.25 + activation: leaky_relu + dropout_rate: 0.2 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml new file mode 100644 index 0000000..82afd93 --- /dev/null +++ b/training/conf/trainer/default.yaml @@ -0,0 +1,18 @@ +# @package _group_ +seed: 4711 +load_checkpoint: null +wandb: false +tune: false +train: true +test: true +logging: INFO +args: + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 64 + terminate_on_nan: true + weights_summary: top diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml deleted file mode 100644 index 13d7c97..0000000 --- a/training/configs/vqvae.yaml +++ /dev/null @@ -1,89 +0,0 @@ -seed: 4711 - -network: - desc: Configuration of the PyTorch neural network. - type: VQVAE - args: - in_channels: 1 - channels: [32, 64, 64, 96, 96] - kernel_sizes: [4, 4, 4, 4, 4] - strides: [2, 2, 2, 2, 2] - num_residual_layers: 2 - embedding_dim: 512 - num_embeddings: 1024 - upsampling: null - beta: 0.25 - activation: leaky_relu - dropout_rate: 0.2 - -model: - desc: Configuration of the PyTorch Lightning model. - type: LitVQVAEModel - args: - optimizer: - type: MADGRAD - args: - lr: 1.0e-3 - momentum: 0.9 - weight_decay: 0 - eps: 1.0e-6 - lr_scheduler: - type: OneCycleLR - args: - interval: &interval step - max_lr: 1.0e-3 - three_phase: true - epochs: 64 - steps_per_epoch: 633 # num_samples / batch_size - criterion: - type: MSELoss - args: - reduction: mean - monitor: val_loss - mapping: sentence_piece - -data: - desc: Configuration of the training/test data. - type: IAMExtendedParagraphs - args: - batch_size: 32 - num_workers: 12 - train_fraction: 0.8 - augment: true - -callbacks: - - type: ModelCheckpoint - args: - monitor: val_loss - mode: min - save_last: true - - type: StochasticWeightAveraging - args: - swa_epoch_start: 0.8 - swa_lrs: 0.05 - annealing_epochs: 10 - annealing_strategy: cos - device: null - - type: LearningRateMonitor - args: - logging_interval: *interval - # - type: EarlyStopping - # args: - # monitor: val_loss - # mode: min - # patience: 10 - -trainer: - desc: Configuration of the PyTorch Lightning Trainer. - args: - stochastic_weight_avg: true - auto_scale_batch_size: binsearch - gradient_clip_val: 0 - fast_dev_run: false - gpus: 1 - precision: 16 - max_epochs: 64 - terminate_on_nan: true - weights_summary: top - -load_checkpoint: null diff --git a/training/run_experiment.py b/training/run_experiment.py index bdefbf0..2b3ecab 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -4,17 +4,15 @@ import importlib from pathlib import Path from typing import Dict, List, Optional, Type -import click +import hydra from loguru import logger -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig import pytorch_lightning as pl from torch import nn from tqdm import tqdm import wandb -SEED = 4711 -CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs" LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs" @@ -29,21 +27,10 @@ def _create_experiment_dir(config: DictConfig) -> Path: return log_dir -def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: +def _configure_logging(log_dir: Optional[Path], level: str) -> None: """Configure the loguru logger for output to terminal and disk.""" - - def _get_level(verbose: int) -> str: - """Sets the logger level.""" - levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"} - verbose = min(verbose, 2) - return levels[verbose] - # Remove default logger to get tqdm to work properly. logger.remove() - - # Fetch verbosity level. - level = _get_level(verbose) - logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level) if log_dir is not None: logger.add( @@ -52,14 +39,6 @@ def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: ) -def _load_config(file_path: Path) -> DictConfig: - """Return experiment config.""" - logger.info(f"Loading config from: {file_path}") - if not file_path.exists(): - raise FileNotFoundError(f"Experiment config not found at: {file_path}") - return OmegaConf.load(file_path) - - def _import_class(module_and_class_name: str) -> type: """Import class from module.""" module_name, class_name = module_and_class_name.rsplit(".", 1) @@ -78,14 +57,16 @@ def _configure_callbacks( def _configure_logger( - network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool + network: Type[nn.Module], config: DictConfig, log_dir: Path ) -> Type[pl.loggers.LightningLoggerBase]: """Configures lightning logger.""" - if use_wandb: + if config.trainer.wandb: + logger.info("Logging model with W&B") pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir)) pl_logger.watch(network) - pl_logger.log_hyperparams(vars(args)) + pl_logger.log_hyperparams(vars(config)) return pl_logger + logger.info("Logging model with Tensorboard") return pl.loggers.TensorBoardLogger(save_dir=str(log_dir)) @@ -110,50 +91,36 @@ def _load_lit_model( lit_model_class: type, network: Type[nn.Module], config: DictConfig ) -> Type[pl.LightningModule]: """Load lightning model.""" - if config.load_checkpoint is not None: + if config.trainer.load_checkpoint is not None: logger.info( - f"Loading network weights from checkpoint: {config.load_checkpoint}" + f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}" ) return lit_model_class.load_from_checkpoint( - config.load_checkpoint, network=network, **config.model.args + config.trainer.load_checkpoint, network=network, **config.model.args ) return lit_model_class(network=network, **config.model.args) -def run( - filename: str, - fast_dev_run: bool, - train: bool, - test: bool, - tune: bool, - use_wandb: bool, - verbose: int = 0, -) -> None: +def run(config: DictConfig) -> None: """Runs experiment.""" - # Load config. - file_path = CONFIGS_DIRNAME / filename - config = _load_config(file_path) - log_dir = _create_experiment_dir(config) - _configure_logging(log_dir, verbose=verbose) + _configure_logging(log_dir, level=config.trainer.logging) logger.info("Starting experiment...") - # Seed everything in the experiment. - logger.info(f"Seeding everthing with seed={SEED}") - pl.utilities.seed.seed_everything(SEED) + pl.utilities.seed.seed_everything(config.trainer.seed) # Load classes. - data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") + data_module_class = _import_class(f"text_recognizer.data.{config.dataset.type}") network_class = _import_class(f"text_recognizer.networks.{config.network.type}") lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}") # Initialize data object and network. - data_module = data_module_class(**config.data.args) + data_module = data_module_class(**config.dataset.args) network = network_class(**data_module.config(), **config.network.args) # Load callback and logger. callbacks = _configure_callbacks(config.callbacks) - pl_logger = _configure_logger(network, config, log_dir, use_wandb) + pl_logger = _configure_logger(network, config, log_dir) # Load ligtning model. lit_model = _load_lit_model(lit_model_class, network, config) @@ -164,55 +131,28 @@ def run( logger=pl_logger, weights_save_path=str(log_dir), ) - if fast_dev_run: - logger.info("Fast dev run...") + + if config.trainer.tune and not config.trainer.args.fast_dev_run: + logger.info("Tuning learning rate and batch size...") + trainer.tune(lit_model, datamodule=data_module) + + if config.trainer.train: + logger.info("Training network...") trainer.fit(lit_model, datamodule=data_module) - else: - if tune: - logger.info("Tuning learning rate and batch size...") - trainer.tune(lit_model, datamodule=data_module) - - if train: - logger.info("Training network...") - trainer.fit(lit_model, datamodule=data_module) - - if test: - logger.info("Testing network...") - trainer.test(lit_model, datamodule=data_module) - - _save_best_weights(callbacks, use_wandb) - - -@click.command() -@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") -@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") -@click.option("--dev", is_flag=True, help="If true, run a fast dev run.") -@click.option( - "--tune", is_flag=True, help="If true, tune hyperparameters for training." -) -@click.option("-t", "--train", is_flag=True, help="If true, train the model.") -@click.option("-e", "--test", is_flag=True, help="If true, test the model.") -@click.option("-v", "--verbose", count=True) -def cli( - experiment_config: str, - use_wandb: bool, - dev: bool, - tune: bool, - train: bool, - test: bool, - verbose: int, -) -> None: - """Run experiment.""" - run( - filename=experiment_config, - fast_dev_run=dev, - train=train, - test=test, - tune=tune, - use_wandb=use_wandb, - verbose=verbose, - ) + + if config.trainer.test and not config.trainer.args.fast_dev_run: + logger.info("Testing network...") + trainer.test(lit_model, datamodule=data_module) + + if not config.trainer.args.fast_dev_run: + _save_best_weights(callbacks, config.trainer.wandb) + + +@hydra.main(config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + """Loads config with hydra.""" + run(cfg) if __name__ == "__main__": - cli() + main() |