summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--poetry.lock52
-rw-r--r--pyproject.toml3
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py2
-rw-r--r--text_recognizer/networks/__init__.py3
-rw-r--r--text_recognizer/networks/cnn_transformer.py364
-rw-r--r--text_recognizer/networks/transformer/__init__.py3
-rw-r--r--text_recognizer/networks/transformer/transformer.py520
-rw-r--r--training/.gitignore1
-rw-r--r--training/conf/callbacks/default.yaml14
-rw-r--r--training/conf/callbacks/swa.yaml16
-rw-r--r--training/conf/cnn_transformer.yaml (renamed from training/configs/cnn_transformer.yaml)0
-rw-r--r--training/conf/config.yaml6
-rw-r--r--training/conf/dataset/iam_extended_paragraphs.yaml7
-rw-r--r--training/conf/model/lit_vqvae.yaml24
-rw-r--r--training/conf/network/vqvae.yaml14
-rw-r--r--training/conf/trainer/default.yaml18
-rw-r--r--training/configs/vqvae.yaml89
-rw-r--r--training/run_experiment.py136
18 files changed, 635 insertions, 637 deletions
diff --git a/poetry.lock b/poetry.lock
index a8034e7..81ccaed 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -37,6 +37,14 @@ optional = false
python-versions = "*"
[[package]]
+name = "antlr4-python3-runtime"
+version = "4.8"
+description = "ANTLR 4.8 runtime for Python 3.7"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
name = "appdirs"
version = "1.4.4"
description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
@@ -622,6 +630,19 @@ numpy = [
]
[[package]]
+name = "hydra-core"
+version = "1.0.6"
+description = "A framework for elegantly configuring complex applications"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+antlr4-python3-runtime = "4.8"
+importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
+omegaconf = ">=2.0.5,<2.1"
+
+[[package]]
name = "idna"
version = "2.10"
description = "Internationalized Domain Names in Applications (IDNA)"
@@ -638,6 +659,18 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
+name = "importlib-resources"
+version = "5.1.2"
+description = "Read resources from Python packages"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.extras]
+docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
+testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "pytest-enabler", "pytest-black (>=0.3.7)", "pytest-mypy"]
+
+[[package]]
name = "ipykernel"
version = "5.5.3"
description = "IPython Kernel for Jupyter"
@@ -2086,7 +2119,7 @@ brotli = ["brotlipy (>=0.6.0)"]
[[package]]
name = "wandb"
-version = "0.10.27"
+version = "0.10.28"
description = "A CLI and library for interacting with the Weights and Biases API."
category = "main"
optional = false
@@ -2197,7 +2230,7 @@ multidict = ">=4.0"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
-content-hash = "07d5f14d7c55a961ce1841ecd125c0f7c83d5649cc118fcae2ed3b58347ca8c2"
+content-hash = "bcc456879df9e9ec937f1aa8c339fe96ca18ce95b12b12fc10c23494829296e8"
[metadata.files]
absl-py = [
@@ -2247,6 +2280,9 @@ alabaster = [
{file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"},
{file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"},
]
+antlr4-python3-runtime = [
+ {file = "antlr4-python3-runtime-4.8.tar.gz", hash = "sha256:15793f5d0512a372b4e7d2284058ad32ce7dd27126b105fb0b2245130445db33"},
+]
appdirs = [
{file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
{file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
@@ -2628,6 +2664,10 @@ h5py = [
{file = "h5py-3.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c5b5f18c96fb63399280a724734fd91e1781c6b60e385e439ad8e654a294ba4"},
{file = "h5py-3.2.1.tar.gz", hash = "sha256:89474be911bfcdb34cbf0d98b8ec48b578c27a89fdb1ae4ee7513f1ef8d9249e"},
]
+hydra-core = [
+ {file = "hydra-core-1.0.6.tar.gz", hash = "sha256:be5fe119eb41ada20c82835b153b956781de7e2506670ba942f7453b2e850950"},
+ {file = "hydra_core-1.0.6-py3-none-any.whl", hash = "sha256:500d4346b7afcd654276c87c15820d7e6b76c2da95ad698cceb4120d7a877b32"},
+]
idna = [
{file = "idna-2.10-py2.py3-none-any.whl", hash = "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"},
{file = "idna-2.10.tar.gz", hash = "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6"},
@@ -2636,6 +2676,10 @@ imagesize = [
{file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"},
{file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"},
]
+importlib-resources = [
+ {file = "importlib_resources-5.1.2-py3-none-any.whl", hash = "sha256:ebab3efe74d83b04d6bf5cd9a17f0c5c93e60fb60f30c90f56265fce4682a469"},
+ {file = "importlib_resources-5.1.2.tar.gz", hash = "sha256:642586fc4740bd1cad7690f836b3321309402b20b332529f25617ff18e8e1370"},
+]
ipykernel = [
{file = "ipykernel-5.5.3-py3-none-any.whl", hash = "sha256:21abd584543759e49010975a4621603b3cf871b1039cb3879a14094717692614"},
{file = "ipykernel-5.5.3.tar.gz", hash = "sha256:a682e4f7affd86d9ce9b699d21bcab6d5ec9fbb2bfcb194f2706973b252bc509"},
@@ -3635,8 +3679,8 @@ urllib3 = [
{file = "urllib3-1.26.4.tar.gz", hash = "sha256:e7b021f7241115872f92f43c6508082facffbd1c048e3c6e2bb9c2a157e28937"},
]
wandb = [
- {file = "wandb-0.10.27-py2.py3-none-any.whl", hash = "sha256:f591bb9d5c402ec5c12bd823db913d49ac31e4068f2c35c50b6f13e4fcd717b4"},
- {file = "wandb-0.10.27.tar.gz", hash = "sha256:1b6d3bbfd644183bbd79a02ff9e57e65a8a4f7c9b770af0d3c4f961ae6d7fc73"},
+ {file = "wandb-0.10.28-py2.py3-none-any.whl", hash = "sha256:609f2605ec82f846490a2bdd9d01fa947b1892a76133cbb80c41501e5dfda763"},
+ {file = "wandb-0.10.28.tar.gz", hash = "sha256:b48aa55f147717e197d38b0dd9d9ef3662efe54fefc0be2909f93e8a5c0cc71b"},
]
wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
diff --git a/pyproject.toml b/pyproject.toml
index f9bce95..6f050c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ nltk = "^3.5"
torch-summary = "^1.4.2"
defusedxml = "^0.6.0"
omegaconf = "^2.0.2"
-wandb = "^0.10.27"
+wandb = "^0.10.28"
einops = "^0.3.0"
gtn = "^0.0.0"
sentencepiece = "^0.1.95"
@@ -41,6 +41,7 @@ Pillow = "^8.1.2"
madgrad = "^1.0"
editdistance = "^0.5.3"
torchmetrics = "^0.2.0"
+hydra-core = "^1.0.6"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 78e6c05..ad6fa25 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -76,7 +76,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def setup(self, stage: str = None) -> None:
"""Loading synthetic dataset."""
- logger.info(f"IAM Synthetic dataset steup for stage {stage}")
+ logger.info(f"IAM Synthetic dataset steup for stage {stage}...")
if stage == "fit" or stage is None:
line_crops, line_labels = load_line_crops_and_labels(
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index a9117f8..d1ebf1a 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,4 +1,5 @@
"""Network modules"""
from .encoders import EfficientNet
from .vqvae import VQVAE
-from .cnn_transformer import CNNTransformer
+
+# from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index d42c29d..80798e1 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,182 +1,182 @@
-"""A Transformer with a cnn backbone.
-
-The network encodes a image with a convolutional backbone to a latent representation,
-i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-spatial information. The resulting feature are then set to a transformer decoder
-together with the target tokens.
-
-TODO: Local attention for lower layer in attention.
-
-"""
-import importlib
-import math
-from typing import Dict, Optional, Union, Sequence, Type
-
-from einops import rearrange
-from omegaconf import DictConfig, OmegaConf
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-from text_recognizer.networks.transformer import (
- Decoder,
- DecoderLayer,
- PositionalEncoding,
- PositionalEncoding2D,
- target_padding_mask,
-)
-
-NUM_WORD_PIECES = 1000
-
-
-class CNNTransformer(nn.Module):
- def __init__(
- self,
- input_dim: Sequence[int],
- output_dims: Sequence[int],
- encoder: Union[DictConfig, Dict],
- vocab_size: Optional[int] = None,
- num_decoder_layers: int = 4,
- hidden_dim: int = 256,
- num_heads: int = 4,
- expansion_dim: int = 1024,
- dropout_rate: float = 0.1,
- transformer_activation: str = "glu",
- *args,
- **kwargs,
- ) -> None:
- super().__init__()
- self.vocab_size = (
- NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
- )
- self.pad_index = 3 # TODO: fix me
- self.hidden_dim = hidden_dim
- self.max_output_length = output_dims[0]
-
- # Image backbone
- self.encoder = self._configure_encoder(encoder)
- self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
- self.feature_map_encoding = PositionalEncoding2D(
- hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
- )
-
- # Target token embedding
- self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.trg_position_encoding = PositionalEncoding(
- hidden_dim, dropout_rate, max_len=output_dims[0]
- )
-
- # Transformer decoder
- self.decoder = Decoder(
- decoder_layer=DecoderLayer(
- hidden_dim=hidden_dim,
- num_heads=num_heads,
- expansion_dim=expansion_dim,
- dropout_rate=dropout_rate,
- activation=transformer_activation,
- ),
- num_layers=num_decoder_layers,
- norm=nn.LayerNorm(hidden_dim),
- )
-
- # Classification head
- self.head = nn.Linear(hidden_dim, self.vocab_size)
-
- # Initialize weights
- self._init_weights()
-
- def _init_weights(self) -> None:
- """Initialize network weights."""
- self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
- self.head.bias.data.zero_()
- self.head.weight.data.uniform_(-0.1, 0.1)
-
- nn.init.kaiming_normal_(
- self.encoder_proj.weight.data,
- a=0,
- mode="fan_out",
- nonlinearity="relu",
- )
- if self.encoder_proj.bias is not None:
- _, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.encoder_proj.weight.data
- )
- bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.encoder_proj.bias, -bound, bound)
-
- @staticmethod
- def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
- encoder = OmegaConf.create(encoder)
- args = encoder.args or {}
- network_module = importlib.import_module("text_recognizer.networks")
- encoder_class = getattr(network_module, encoder.type)
- return encoder_class(**args)
-
- def encode(self, image: Tensor) -> Tensor:
- """Extracts image features with backbone.
-
- Args:
- image (Tensor): Image(s) of handwritten text.
-
- Retuns:
- Tensor: Image features.
-
- Shapes:
- - image: :math: `(B, C, H, W)`
- - latent: :math: `(B, T, C)`
-
- """
- # Extract image features.
- image_features = self.encoder(image)
- image_features = self.encoder_proj(image_features)
-
- # Add 2d encoding to the feature maps.
- image_features = self.feature_map_encoding(image_features)
-
- # Collapse features maps height and width.
- image_features = rearrange(image_features, "b c h w -> b (h w) c")
- return image_features
-
- def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
- """Decodes image features with transformer decoder."""
- trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
- trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = rearrange(trg, "b t d -> t b d")
- trg = self.trg_position_encoding(trg)
- trg = rearrange(trg, "t b d -> b t d")
- out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
- logits = self.head(out)
- return logits
-
- def forward(self, image: Tensor, trg: Tensor) -> Tensor:
- image_features = self.encode(image)
- output = self.decode(image_features, trg)
- output = rearrange(output, "b t c -> b c t")
- return output
-
- def predict(self, image: Tensor) -> Tensor:
- """Transcribes text in image(s)."""
- bsz = image.shape[0]
- image_features = self.encode(image)
-
- output_tokens = (
- (torch.ones((bsz, self.max_output_length)) * self.pad_index)
- .type_as(image)
- .long()
- )
- output_tokens[:, 0] = self.start_index
- for i in range(1, self.max_output_length):
- trg = output_tokens[:, :i]
- output = self.decode(image_features, trg)
- output = torch.argmax(output, dim=-1)
- output_tokens[:, i] = output[-1:]
-
- # Set all tokens after end token to be padding.
- for i in range(1, self.max_output_length):
- indices = output_tokens[:, i - 1] == self.end_index | (
- output_tokens[:, i - 1] == self.pad_index
- )
- output_tokens[indices, i] = self.pad_index
-
- return output_tokens
+# """A Transformer with a cnn backbone.
+#
+# The network encodes a image with a convolutional backbone to a latent representation,
+# i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+# spatial information. The resulting feature are then set to a transformer decoder
+# together with the target tokens.
+#
+# TODO: Local attention for lower layer in attention.
+#
+# """
+# import importlib
+# import math
+# from typing import Dict, Optional, Union, Sequence, Type
+#
+# from einops import rearrange
+# from omegaconf import DictConfig, OmegaConf
+# import torch
+# from torch import nn
+# from torch import Tensor
+#
+# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
+# from text_recognizer.networks.transformer import (
+# Decoder,
+# DecoderLayer,
+# PositionalEncoding,
+# PositionalEncoding2D,
+# target_padding_mask,
+# )
+#
+# NUM_WORD_PIECES = 1000
+#
+#
+# class CNNTransformer(nn.Module):
+# def __init__(
+# self,
+# input_dim: Sequence[int],
+# output_dims: Sequence[int],
+# encoder: Union[DictConfig, Dict],
+# vocab_size: Optional[int] = None,
+# num_decoder_layers: int = 4,
+# hidden_dim: int = 256,
+# num_heads: int = 4,
+# expansion_dim: int = 1024,
+# dropout_rate: float = 0.1,
+# transformer_activation: str = "glu",
+# *args,
+# **kwargs,
+# ) -> None:
+# super().__init__()
+# self.vocab_size = (
+# NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+# )
+# self.pad_index = 3 # TODO: fix me
+# self.hidden_dim = hidden_dim
+# self.max_output_length = output_dims[0]
+#
+# # Image backbone
+# self.encoder = self._configure_encoder(encoder)
+# self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
+# self.feature_map_encoding = PositionalEncoding2D(
+# hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
+# )
+#
+# # Target token embedding
+# self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+# self.trg_position_encoding = PositionalEncoding(
+# hidden_dim, dropout_rate, max_len=output_dims[0]
+# )
+#
+# # Transformer decoder
+# self.decoder = Decoder(
+# decoder_layer=DecoderLayer(
+# hidden_dim=hidden_dim,
+# num_heads=num_heads,
+# expansion_dim=expansion_dim,
+# dropout_rate=dropout_rate,
+# activation=transformer_activation,
+# ),
+# num_layers=num_decoder_layers,
+# norm=nn.LayerNorm(hidden_dim),
+# )
+#
+# # Classification head
+# self.head = nn.Linear(hidden_dim, self.vocab_size)
+#
+# # Initialize weights
+# self._init_weights()
+#
+# def _init_weights(self) -> None:
+# """Initialize network weights."""
+# self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+# self.head.bias.data.zero_()
+# self.head.weight.data.uniform_(-0.1, 0.1)
+#
+# nn.init.kaiming_normal_(
+# self.encoder_proj.weight.data,
+# a=0,
+# mode="fan_out",
+# nonlinearity="relu",
+# )
+# if self.encoder_proj.bias is not None:
+# _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+# self.encoder_proj.weight.data
+# )
+# bound = 1 / math.sqrt(fan_out)
+# nn.init.normal_(self.encoder_proj.bias, -bound, bound)
+#
+# @staticmethod
+# def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
+# encoder = OmegaConf.create(encoder)
+# args = encoder.args or {}
+# network_module = importlib.import_module("text_recognizer.networks")
+# encoder_class = getattr(network_module, encoder.type)
+# return encoder_class(**args)
+#
+# def encode(self, image: Tensor) -> Tensor:
+# """Extracts image features with backbone.
+#
+# Args:
+# image (Tensor): Image(s) of handwritten text.
+#
+# Retuns:
+# Tensor: Image features.
+#
+# Shapes:
+# - image: :math: `(B, C, H, W)`
+# - latent: :math: `(B, T, C)`
+#
+# """
+# # Extract image features.
+# image_features = self.encoder(image)
+# image_features = self.encoder_proj(image_features)
+#
+# # Add 2d encoding to the feature maps.
+# image_features = self.feature_map_encoding(image_features)
+#
+# # Collapse features maps height and width.
+# image_features = rearrange(image_features, "b c h w -> b (h w) c")
+# return image_features
+#
+# def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+# """Decodes image features with transformer decoder."""
+# trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+# trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+# trg = rearrange(trg, "b t d -> t b d")
+# trg = self.trg_position_encoding(trg)
+# trg = rearrange(trg, "t b d -> b t d")
+# out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
+# logits = self.head(out)
+# return logits
+#
+# def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+# image_features = self.encode(image)
+# output = self.decode(image_features, trg)
+# output = rearrange(output, "b t c -> b c t")
+# return output
+#
+# def predict(self, image: Tensor) -> Tensor:
+# """Transcribes text in image(s)."""
+# bsz = image.shape[0]
+# image_features = self.encode(image)
+#
+# output_tokens = (
+# (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+# .type_as(image)
+# .long()
+# )
+# output_tokens[:, 0] = self.start_index
+# for i in range(1, self.max_output_length):
+# trg = output_tokens[:, :i]
+# output = self.decode(image_features, trg)
+# output = torch.argmax(output, dim=-1)
+# output_tokens[:, i] = output[-1:]
+#
+# # Set all tokens after end token to be padding.
+# for i in range(1, self.max_output_length):
+# indices = output_tokens[:, i - 1] == self.end_index | (
+# output_tokens[:, i - 1] == self.pad_index
+# )
+# output_tokens[indices, i] = self.pad_index
+#
+# return output_tokens
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 627fa7b..4ff48f7 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -4,4 +4,5 @@ from .positional_encoding import (
PositionalEncoding2D,
target_padding_mask,
)
-from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer
+
+# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index 5ac2787..d49c85a 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -1,260 +1,260 @@
-"""Transfomer module."""
-import copy
-from typing import Dict, Optional, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.transformer.attention import MultiHeadAttention
-from text_recognizer.networks.util import activation_function
-
-
-class GEGLU(nn.Module):
- """GLU activation for improving feedforward activations."""
-
- def __init__(self, dim_in: int, dim_out: int) -> None:
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward propagation."""
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
- return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
-
-
-class _IntraLayerConnection(nn.Module):
- """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
-
- def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
- super().__init__()
- self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward(self, src: Tensor, residual: Tensor) -> Tensor:
- return self.norm(self.dropout(src) + residual)
-
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- hidden_dim: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
-
- in_projection = (
- nn.Sequential(
- nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
- )
- if activation != "glu"
- else GEGLU(hidden_dim, expansion_dim)
- )
-
- self.layer = nn.Sequential(
- in_projection,
- nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layer(x)
-
-
-class EncoderLayer(nn.Module):
- """Transfomer encoding layer."""
-
- def __init__(
- self,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
- self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
- self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-
- def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
- """Forward pass through the encoder."""
- # First block.
- # Multi head attention.
- out, _ = self.self_attention(src, src, src, mask)
-
- # Add & norm.
- out = self.block1(out, src)
-
- # Second block.
- # Apply 1D-convolution.
- mlp_out = self.mlp(out)
-
- # Add & norm.
- out = self.block2(mlp_out, out)
-
- return out
-
-
-class Encoder(nn.Module):
- """Transfomer encoder module."""
-
- def __init__(
- self,
- num_layers: int,
- encoder_layer: Type[nn.Module],
- norm: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
- self.norm = norm
-
- def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
- """Forward pass through all encoder layers."""
- for layer in self.layers:
- src = layer(src, src_mask)
-
- if self.norm is not None:
- src = self.norm(src)
-
- return src
-
-
-class DecoderLayer(nn.Module):
- """Transfomer decoder layer."""
-
- def __init__(
- self,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float = 0.0,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
- self.multihead_attention = MultiHeadAttention(
- hidden_dim, num_heads, dropout_rate
- )
- self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
- self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
-
- def forward(
- self,
- trg: Tensor,
- memory: Tensor,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass of the layer."""
- out, _ = self.self_attention(trg, trg, trg, trg_mask)
- trg = self.block1(out, trg)
-
- out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
- trg = self.block2(out, trg)
-
- out = self.mlp(trg)
- out = self.block3(out, trg)
-
- return out
-
-
-class Decoder(nn.Module):
- """Transfomer decoder module."""
-
- def __init__(
- self,
- decoder_layer: Type[nn.Module],
- num_layers: int,
- norm: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
-
- def forward(
- self,
- trg: Tensor,
- memory: Tensor,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass through the decoder."""
- for layer in self.layers:
- trg = layer(trg, memory, trg_mask, memory_mask)
-
- if self.norm is not None:
- trg = self.norm(trg)
-
- return trg
-
-
-class Transformer(nn.Module):
- """Transformer network."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
-
- # Configure encoder.
- encoder_norm = nn.LayerNorm(hidden_dim)
- encoder_layer = EncoderLayer(
- hidden_dim, num_heads, expansion_dim, dropout_rate, activation
- )
- self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
-
- # Configure decoder.
- decoder_norm = nn.LayerNorm(hidden_dim)
- decoder_layer = DecoderLayer(
- hidden_dim, num_heads, expansion_dim, dropout_rate, activation
- )
- self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
-
- self._reset_parameters()
-
- def _reset_parameters(self) -> None:
- for p in self.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
-
- def forward(
- self,
- src: Tensor,
- trg: Tensor,
- src_mask: Optional[Tensor] = None,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass through the transformer."""
- if src.shape[0] != trg.shape[0]:
- print(trg.shape)
- raise RuntimeError("The batch size of the src and trg must be the same.")
- if src.shape[2] != trg.shape[2]:
- raise RuntimeError(
- "The number of features for the src and trg must be the same."
- )
-
- memory = self.encoder(src, src_mask)
- output = self.decoder(trg, memory, trg_mask, memory_mask)
- return output
+# """Transfomer module."""
+# import copy
+# from typing import Dict, Optional, Type, Union
+#
+# import numpy as np
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+#
+# from text_recognizer.networks.transformer.attention import MultiHeadAttention
+# from text_recognizer.networks.util import activation_function
+#
+#
+# class GEGLU(nn.Module):
+# """GLU activation for improving feedforward activations."""
+#
+# def __init__(self, dim_in: int, dim_out: int) -> None:
+# super().__init__()
+# self.proj = nn.Linear(dim_in, dim_out * 2)
+#
+# def forward(self, x: Tensor) -> Tensor:
+# """Forward propagation."""
+# x, gate = self.proj(x).chunk(2, dim=-1)
+# return x * F.gelu(gate)
+#
+#
+# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+#
+#
+# class _IntraLayerConnection(nn.Module):
+# """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+#
+# def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+# super().__init__()
+# self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+# self.dropout = nn.Dropout(p=dropout_rate)
+#
+# def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+# return self.norm(self.dropout(src) + residual)
+#
+#
+# class FeedForward(nn.Module):
+# def __init__(
+# self,
+# hidden_dim: int,
+# expansion_dim: int,
+# dropout_rate: float,
+# activation: str = "relu",
+# ) -> None:
+# super().__init__()
+#
+# in_projection = (
+# nn.Sequential(
+# nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+# )
+# if activation != "glu"
+# else GEGLU(hidden_dim, expansion_dim)
+# )
+#
+# self.layer = nn.Sequential(
+# in_projection,
+# nn.Dropout(p=dropout_rate),
+# nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+# )
+#
+# def forward(self, x: Tensor) -> Tensor:
+# return self.layer(x)
+#
+#
+# class EncoderLayer(nn.Module):
+# """Transfomer encoding layer."""
+#
+# def __init__(
+# self,
+# hidden_dim: int,
+# num_heads: int,
+# expansion_dim: int,
+# dropout_rate: float,
+# activation: str = "relu",
+# ) -> None:
+# super().__init__()
+# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
+# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+#
+# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+# """Forward pass through the encoder."""
+# # First block.
+# # Multi head attention.
+# out, _ = self.self_attention(src, src, src, mask)
+#
+# # Add & norm.
+# out = self.block1(out, src)
+#
+# # Second block.
+# # Apply 1D-convolution.
+# mlp_out = self.mlp(out)
+#
+# # Add & norm.
+# out = self.block2(mlp_out, out)
+#
+# return out
+#
+#
+# class Encoder(nn.Module):
+# """Transfomer encoder module."""
+#
+# def __init__(
+# self,
+# num_layers: int,
+# encoder_layer: Type[nn.Module],
+# norm: Optional[Type[nn.Module]] = None,
+# ) -> None:
+# super().__init__()
+# self.layers = _get_clones(encoder_layer, num_layers)
+# self.norm = norm
+#
+# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+# """Forward pass through all encoder layers."""
+# for layer in self.layers:
+# src = layer(src, src_mask)
+#
+# if self.norm is not None:
+# src = self.norm(src)
+#
+# return src
+#
+#
+# class DecoderLayer(nn.Module):
+# """Transfomer decoder layer."""
+#
+# def __init__(
+# self,
+# hidden_dim: int,
+# num_heads: int,
+# expansion_dim: int,
+# dropout_rate: float = 0.0,
+# activation: str = "relu",
+# ) -> None:
+# super().__init__()
+# self.hidden_dim = hidden_dim
+# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+# self.multihead_attention = MultiHeadAttention(
+# hidden_dim, num_heads, dropout_rate
+# )
+# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
+# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+#
+# def forward(
+# self,
+# trg: Tensor,
+# memory: Tensor,
+# trg_mask: Optional[Tensor] = None,
+# memory_mask: Optional[Tensor] = None,
+# ) -> Tensor:
+# """Forward pass of the layer."""
+# out, _ = self.self_attention(trg, trg, trg, trg_mask)
+# trg = self.block1(out, trg)
+#
+# out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+# trg = self.block2(out, trg)
+#
+# out = self.mlp(trg)
+# out = self.block3(out, trg)
+#
+# return out
+#
+#
+# class Decoder(nn.Module):
+# """Transfomer decoder module."""
+#
+# def __init__(
+# self,
+# decoder_layer: Type[nn.Module],
+# num_layers: int,
+# norm: Optional[Type[nn.Module]] = None,
+# ) -> None:
+# super().__init__()
+# self.layers = _get_clones(decoder_layer, num_layers)
+# self.num_layers = num_layers
+# self.norm = norm
+#
+# def forward(
+# self,
+# trg: Tensor,
+# memory: Tensor,
+# trg_mask: Optional[Tensor] = None,
+# memory_mask: Optional[Tensor] = None,
+# ) -> Tensor:
+# """Forward pass through the decoder."""
+# for layer in self.layers:
+# trg = layer(trg, memory, trg_mask, memory_mask)
+#
+# if self.norm is not None:
+# trg = self.norm(trg)
+#
+# return trg
+#
+#
+# class Transformer(nn.Module):
+# """Transformer network."""
+#
+# def __init__(
+# self,
+# num_encoder_layers: int,
+# num_decoder_layers: int,
+# hidden_dim: int,
+# num_heads: int,
+# expansion_dim: int,
+# dropout_rate: float,
+# activation: str = "relu",
+# ) -> None:
+# super().__init__()
+#
+# # Configure encoder.
+# encoder_norm = nn.LayerNorm(hidden_dim)
+# encoder_layer = EncoderLayer(
+# hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+# )
+# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+#
+# # Configure decoder.
+# decoder_norm = nn.LayerNorm(hidden_dim)
+# decoder_layer = DecoderLayer(
+# hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+# )
+# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+#
+# self._reset_parameters()
+#
+# def _reset_parameters(self) -> None:
+# for p in self.parameters():
+# if p.dim() > 1:
+# nn.init.xavier_uniform_(p)
+#
+# def forward(
+# self,
+# src: Tensor,
+# trg: Tensor,
+# src_mask: Optional[Tensor] = None,
+# trg_mask: Optional[Tensor] = None,
+# memory_mask: Optional[Tensor] = None,
+# ) -> Tensor:
+# """Forward pass through the transformer."""
+# if src.shape[0] != trg.shape[0]:
+# print(trg.shape)
+# raise RuntimeError("The batch size of the src and trg must be the same.")
+# if src.shape[2] != trg.shape[2]:
+# raise RuntimeError(
+# "The number of features for the src and trg must be the same."
+# )
+#
+# memory = self.encoder(src, src_mask)
+# output = self.decoder(trg, memory, trg_mask, memory_mask)
+# return output
diff --git a/training/.gitignore b/training/.gitignore
index 333c1e9..7d268ea 100644
--- a/training/.gitignore
+++ b/training/.gitignore
@@ -1 +1,2 @@
logs/
+outputs/
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
new file mode 100644
index 0000000..74dc30c
--- /dev/null
+++ b/training/conf/callbacks/default.yaml
@@ -0,0 +1,14 @@
+# @package _group_
+- type: ModelCheckpoint
+ args:
+ monitor: val_loss
+ mode: min
+ save_last: true
+- type: LearningRateMonitor
+ args:
+ logging_interval: step
+# - type: EarlyStopping
+# args:
+# monitor: val_loss
+# mode: min
+# patience: 10
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
new file mode 100644
index 0000000..144ad6e
--- /dev/null
+++ b/training/conf/callbacks/swa.yaml
@@ -0,0 +1,16 @@
+# @package _group_
+- type: ModelCheckpoint
+ args:
+ monitor: val_loss
+ mode: min
+ save_last: true
+- type: StochasticWeightAveraging
+ args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+- type: LearningRateMonitor
+ args:
+ logging_interval: step
diff --git a/training/configs/cnn_transformer.yaml b/training/conf/cnn_transformer.yaml
index a4f16df..a4f16df 100644
--- a/training/configs/cnn_transformer.yaml
+++ b/training/conf/cnn_transformer.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
new file mode 100644
index 0000000..11adeb7
--- /dev/null
+++ b/training/conf/config.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - network: vqvae
+ - model: lit_vqvae
+ - dataset: iam_extended_paragraphs
+ - trainer: default
+ - callbacks: default
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
new file mode 100644
index 0000000..6bd7fc9
--- /dev/null
+++ b/training/conf/dataset/iam_extended_paragraphs.yaml
@@ -0,0 +1,7 @@
+# @package _group_
+type: IAMExtendedParagraphs
+args:
+ batch_size: 32
+ num_workers: 12
+ train_fraction: 0.8
+ augment: true
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
new file mode 100644
index 0000000..90780b7
--- /dev/null
+++ b/training/conf/model/lit_vqvae.yaml
@@ -0,0 +1,24 @@
+# @package _group_
+type: LitVQVAEModel
+args:
+ optimizer:
+ type: MADGRAD
+ args:
+ lr: 1.0e-3
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
+ lr_scheduler:
+ type: OneCycleLR
+ args:
+ interval: step
+ max_lr: 1.0e-3
+ three_phase: true
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
+ criterion:
+ type: MSELoss
+ args:
+ reduction: mean
+ monitor: val_loss
+ mapping: sentence_piece
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
new file mode 100644
index 0000000..8c30bbd
--- /dev/null
+++ b/training/conf/network/vqvae.yaml
@@ -0,0 +1,14 @@
+# @package _group_
+type: VQVAE
+args:
+ in_channels: 1
+ channels: [32, 64, 64]
+ kernel_sizes: [4, 4, 4]
+ strides: [2, 2, 2]
+ num_residual_layers: 2
+ embedding_dim: 64
+ num_embeddings: 256
+ upsampling: null
+ beta: 0.25
+ activation: leaky_relu
+ dropout_rate: 0.2
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
new file mode 100644
index 0000000..82afd93
--- /dev/null
+++ b/training/conf/trainer/default.yaml
@@ -0,0 +1,18 @@
+# @package _group_
+seed: 4711
+load_checkpoint: null
+wandb: false
+tune: false
+train: true
+test: true
+logging: INFO
+args:
+ stochastic_weight_avg: false
+ auto_scale_batch_size: binsearch
+ gradient_clip_val: 0
+ fast_dev_run: false
+ gpus: 1
+ precision: 16
+ max_epochs: 64
+ terminate_on_nan: true
+ weights_summary: top
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
deleted file mode 100644
index 13d7c97..0000000
--- a/training/configs/vqvae.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-seed: 4711
-
-network:
- desc: Configuration of the PyTorch neural network.
- type: VQVAE
- args:
- in_channels: 1
- channels: [32, 64, 64, 96, 96]
- kernel_sizes: [4, 4, 4, 4, 4]
- strides: [2, 2, 2, 2, 2]
- num_residual_layers: 2
- embedding_dim: 512
- num_embeddings: 1024
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
-
-model:
- desc: Configuration of the PyTorch Lightning model.
- type: LitVQVAEModel
- args:
- optimizer:
- type: MADGRAD
- args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- lr_scheduler:
- type: OneCycleLR
- args:
- interval: &interval step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
- criterion:
- type: MSELoss
- args:
- reduction: mean
- monitor: val_loss
- mapping: sentence_piece
-
-data:
- desc: Configuration of the training/test data.
- type: IAMExtendedParagraphs
- args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
-
-callbacks:
- - type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
- - type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
- - type: LearningRateMonitor
- args:
- logging_interval: *interval
- # - type: EarlyStopping
- # args:
- # monitor: val_loss
- # mode: min
- # patience: 10
-
-trainer:
- desc: Configuration of the PyTorch Lightning Trainer.
- args:
- stochastic_weight_avg: true
- auto_scale_batch_size: binsearch
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: 64
- terminate_on_nan: true
- weights_summary: top
-
-load_checkpoint: null
diff --git a/training/run_experiment.py b/training/run_experiment.py
index bdefbf0..2b3ecab 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -4,17 +4,15 @@ import importlib
from pathlib import Path
from typing import Dict, List, Optional, Type
-import click
+import hydra
from loguru import logger
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
import wandb
-SEED = 4711
-CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs"
LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs"
@@ -29,21 +27,10 @@ def _create_experiment_dir(config: DictConfig) -> Path:
return log_dir
-def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
+def _configure_logging(log_dir: Optional[Path], level: str) -> None:
"""Configure the loguru logger for output to terminal and disk."""
-
- def _get_level(verbose: int) -> str:
- """Sets the logger level."""
- levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
- verbose = min(verbose, 2)
- return levels[verbose]
-
# Remove default logger to get tqdm to work properly.
logger.remove()
-
- # Fetch verbosity level.
- level = _get_level(verbose)
-
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
if log_dir is not None:
logger.add(
@@ -52,14 +39,6 @@ def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
)
-def _load_config(file_path: Path) -> DictConfig:
- """Return experiment config."""
- logger.info(f"Loading config from: {file_path}")
- if not file_path.exists():
- raise FileNotFoundError(f"Experiment config not found at: {file_path}")
- return OmegaConf.load(file_path)
-
-
def _import_class(module_and_class_name: str) -> type:
"""Import class from module."""
module_name, class_name = module_and_class_name.rsplit(".", 1)
@@ -78,14 +57,16 @@ def _configure_callbacks(
def _configure_logger(
- network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool
+ network: Type[nn.Module], config: DictConfig, log_dir: Path
) -> Type[pl.loggers.LightningLoggerBase]:
"""Configures lightning logger."""
- if use_wandb:
+ if config.trainer.wandb:
+ logger.info("Logging model with W&B")
pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir))
pl_logger.watch(network)
- pl_logger.log_hyperparams(vars(args))
+ pl_logger.log_hyperparams(vars(config))
return pl_logger
+ logger.info("Logging model with Tensorboard")
return pl.loggers.TensorBoardLogger(save_dir=str(log_dir))
@@ -110,50 +91,36 @@ def _load_lit_model(
lit_model_class: type, network: Type[nn.Module], config: DictConfig
) -> Type[pl.LightningModule]:
"""Load lightning model."""
- if config.load_checkpoint is not None:
+ if config.trainer.load_checkpoint is not None:
logger.info(
- f"Loading network weights from checkpoint: {config.load_checkpoint}"
+ f"Loading network weights from checkpoint: {config.trainer.load_checkpoint}"
)
return lit_model_class.load_from_checkpoint(
- config.load_checkpoint, network=network, **config.model.args
+ config.trainer.load_checkpoint, network=network, **config.model.args
)
return lit_model_class(network=network, **config.model.args)
-def run(
- filename: str,
- fast_dev_run: bool,
- train: bool,
- test: bool,
- tune: bool,
- use_wandb: bool,
- verbose: int = 0,
-) -> None:
+def run(config: DictConfig) -> None:
"""Runs experiment."""
- # Load config.
- file_path = CONFIGS_DIRNAME / filename
- config = _load_config(file_path)
-
log_dir = _create_experiment_dir(config)
- _configure_logging(log_dir, verbose=verbose)
+ _configure_logging(log_dir, level=config.trainer.logging)
logger.info("Starting experiment...")
- # Seed everything in the experiment.
- logger.info(f"Seeding everthing with seed={SEED}")
- pl.utilities.seed.seed_everything(SEED)
+ pl.utilities.seed.seed_everything(config.trainer.seed)
# Load classes.
- data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
+ data_module_class = _import_class(f"text_recognizer.data.{config.dataset.type}")
network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")
# Initialize data object and network.
- data_module = data_module_class(**config.data.args)
+ data_module = data_module_class(**config.dataset.args)
network = network_class(**data_module.config(), **config.network.args)
# Load callback and logger.
callbacks = _configure_callbacks(config.callbacks)
- pl_logger = _configure_logger(network, config, log_dir, use_wandb)
+ pl_logger = _configure_logger(network, config, log_dir)
# Load ligtning model.
lit_model = _load_lit_model(lit_model_class, network, config)
@@ -164,55 +131,28 @@ def run(
logger=pl_logger,
weights_save_path=str(log_dir),
)
- if fast_dev_run:
- logger.info("Fast dev run...")
+
+ if config.trainer.tune and not config.trainer.args.fast_dev_run:
+ logger.info("Tuning learning rate and batch size...")
+ trainer.tune(lit_model, datamodule=data_module)
+
+ if config.trainer.train:
+ logger.info("Training network...")
trainer.fit(lit_model, datamodule=data_module)
- else:
- if tune:
- logger.info("Tuning learning rate and batch size...")
- trainer.tune(lit_model, datamodule=data_module)
-
- if train:
- logger.info("Training network...")
- trainer.fit(lit_model, datamodule=data_module)
-
- if test:
- logger.info("Testing network...")
- trainer.test(lit_model, datamodule=data_module)
-
- _save_best_weights(callbacks, use_wandb)
-
-
-@click.command()
-@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
-@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
-@click.option("--dev", is_flag=True, help="If true, run a fast dev run.")
-@click.option(
- "--tune", is_flag=True, help="If true, tune hyperparameters for training."
-)
-@click.option("-t", "--train", is_flag=True, help="If true, train the model.")
-@click.option("-e", "--test", is_flag=True, help="If true, test the model.")
-@click.option("-v", "--verbose", count=True)
-def cli(
- experiment_config: str,
- use_wandb: bool,
- dev: bool,
- tune: bool,
- train: bool,
- test: bool,
- verbose: int,
-) -> None:
- """Run experiment."""
- run(
- filename=experiment_config,
- fast_dev_run=dev,
- train=train,
- test=test,
- tune=tune,
- use_wandb=use_wandb,
- verbose=verbose,
- )
+
+ if config.trainer.test and not config.trainer.args.fast_dev_run:
+ logger.info("Testing network...")
+ trainer.test(lit_model, datamodule=data_module)
+
+ if not config.trainer.args.fast_dev_run:
+ _save_best_weights(callbacks, config.trainer.wandb)
+
+
+@hydra.main(config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+ """Loads config with hydra."""
+ run(cfg)
if __name__ == "__main__":
- cli()
+ main()