diff options
Diffstat (limited to 'src')
-rwxr-xr-x | src/tasks/create_emnist_lines_datasets.sh | 2 | ||||
-rwxr-xr-x | src/tasks/create_iam_paragraphs.sh | 2 | ||||
-rwxr-xr-x | src/tasks/download_emnist.sh | 3 | ||||
-rwxr-xr-x | src/tasks/download_iam.sh | 2 | ||||
-rwxr-xr-x | src/tasks/prepare_experiments.sh | 2 | ||||
-rwxr-xr-x | src/tasks/train_crnn_line_ctc_model.sh | 1 | ||||
-rwxr-xr-x | src/tasks/train_embedding_model.sh | 3 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 9 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_paragraphs_dataset.py | 7 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 10 | ||||
-rw-r--r-- | src/text_recognizer/tests/support/create_emnist_support_files.py | 13 | ||||
-rw-r--r-- | src/training/experiments/default_config_emnist.yml | 1 | ||||
-rw-r--r-- | src/training/experiments/embedding_experiment.yml | 58 | ||||
-rw-r--r-- | src/training/experiments/line_ctc_experiment.yml | 3 | ||||
-rw-r--r-- | src/training/experiments/sample_experiment.yml | 1 | ||||
-rw-r--r-- | src/training/run_experiment.py | 3 | ||||
-rw-r--r-- | src/training/trainer/callbacks/base.py | 2 |
18 files changed, 101 insertions, 23 deletions
diff --git a/src/tasks/create_emnist_lines_datasets.sh b/src/tasks/create_emnist_lines_datasets.sh new file mode 100755 index 0000000..98d4f8b --- /dev/null +++ b/src/tasks/create_emnist_lines_datasets.sh @@ -0,0 +1,2 @@ +#!/bin/bash +poetry run create-emnist-lines-datasets diff --git a/src/tasks/create_iam_paragraphs.sh b/src/tasks/create_iam_paragraphs.sh new file mode 100755 index 0000000..9063d44 --- /dev/null +++ b/src/tasks/create_iam_paragraphs.sh @@ -0,0 +1,2 @@ +#!/bin/bash +poetry run create-iam-paragraphs diff --git a/src/tasks/download_emnist.sh b/src/tasks/download_emnist.sh new file mode 100755 index 0000000..d142324 --- /dev/null +++ b/src/tasks/download_emnist.sh @@ -0,0 +1,3 @@ +#!/bin/bash +poetry run download-emnist +poetry run create-emnist-support-files diff --git a/src/tasks/download_iam.sh b/src/tasks/download_iam.sh new file mode 100755 index 0000000..4bf011c --- /dev/null +++ b/src/tasks/download_iam.sh @@ -0,0 +1,2 @@ +#!/bin/bash +poetry run download-iam diff --git a/src/tasks/prepare_experiments.sh b/src/tasks/prepare_experiments.sh index 9b91daa..f2787b3 100755 --- a/src/tasks/prepare_experiments.sh +++ b/src/tasks/prepare_experiments.sh @@ -1,3 +1,3 @@ #!/bin/bash experiments_filename=${1:-training/experiments/sample_experiment.yml} -python training/prepare_experiments.py --experiments_filename $experiments_filename +poetry run prepare-experiments --experiments_filename $experiments_filename diff --git a/src/tasks/train_crnn_line_ctc_model.sh b/src/tasks/train_crnn_line_ctc_model.sh index 0f83668..020c4a6 100755 --- a/src/tasks/train_crnn_line_ctc_model.sh +++ b/src/tasks/train_crnn_line_ctc_model.sh @@ -1,4 +1,5 @@ #!/bin/bash experiments_filename=${1:-training/experiments/line_ctc_experiment.yml} OUTPUT=$(./tasks/prepare_experiments.sh $experiments_filename) +echo $OUTPUT eval $OUTPUT diff --git a/src/tasks/train_embedding_model.sh b/src/tasks/train_embedding_model.sh index 85a4a6d..da59116 100755 --- a/src/tasks/train_embedding_model.sh +++ b/src/tasks/train_embedding_model.sh @@ -1,4 +1,5 @@ #!/bin/bash -experiments_filename=${1:-training/experiments/embedding_encoder.yml} +experiments_filename=${1:-training/experiments/embedding_experiment.yml} OUTPUT=$(./tasks/prepare_experiments.sh $experiments_filename) +echo $OUTPUT eval $OUTPUT diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 6268a01..beb5343 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -149,6 +149,7 @@ class EmnistLinesDataset(Dataset): # Load emnist dataset. emnist = EmnistDataset(train=self.train, sample_to_balance=True) + emnist.load_or_generate_data() samples_by_character = get_samples_by_character( emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, @@ -306,17 +307,13 @@ def create_datasets( num_test: int = 1000, ) -> None: """Creates a training an validation dataset of Emnist lines.""" - emnist_train = EmnistDataset(train=True, sample_to_balance=True) - emnist_test = EmnistDataset(train=False, sample_to_balance=True) - datasets = [emnist_train, emnist_test] num_samples = [num_train, num_test] - for num, train, dataset in zip(num_samples, [True, False], datasets): + for num, train in zip(num_samples, [True, False]): emnist_lines = EmnistLinesDataset( train=train, - emnist=dataset, max_length=max_length, min_overlap=min_overlap, max_overlap=max_overlap, num_samples=num, ) - emnist_lines._load_or_generate_data() + emnist_lines.load_or_generate_data() diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py index 4b34bd1..c1e8fe2 100644 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None: @click.option( "--subsample_fraction", type=float, - default=0.0, + default=None, help="The subsampling factor of the dataset.", ) def main(subsample_fraction: float) -> None: """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) dataset.load_or_generate_data() print(dataset) diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 73968a1..125f05a 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -26,7 +26,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: mapping = [(i, str(label)) for i, label in enumerate(labels)] essentials = { "mapping": mapping, - "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), } logger.info("Saving emnist essentials...") with open(ESSENTIALS_FILENAME, "w") as f: diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index caf8065..e89b670 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -356,7 +356,8 @@ class Model(ABC): state["optimizer_state"] = self._optimizer.state_dict() if self._lr_scheduler is not None: - state["scheduler_state"] = self._lr_scheduler.state_dict() + state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() + state["scheduler_interval"] = self._lr_scheduler["interval"] if self._swa_network is not None: state["swa_network"] = self._swa_network.state_dict() @@ -383,8 +384,11 @@ class Model(ABC): if self._lr_scheduler is not None: # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs # with OneCycleLR. - if self._lr_scheduler.__class__.__name__ != "OneCycleLR": - self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": + self._lr_scheduler["lr_scheduler"].load_state_dict( + checkpoint["scheduler_state"] + ) + self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] if self._swa_network is not None: self._swa_network.load_state_dict(checkpoint["swa_network"]) diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py index 5dd1a81..c04860d 100644 --- a/src/text_recognizer/tests/support/create_emnist_support_files.py +++ b/src/text_recognizer/tests/support/create_emnist_support_files.py @@ -2,10 +2,8 @@ from pathlib import Path import shutil -from text_recognizer.datasets.emnist_dataset import ( - fetch_emnist_dataset, - load_emnist_mapping, -) +from text_recognizer.datasets.emnist_dataset import EmnistDataset +from text_recognizer.datasets.util import EmnistMapper from text_recognizer.util import write_image SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" @@ -16,15 +14,16 @@ def create_emnist_support_files() -> None: shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) SUPPORT_DIRNAME.mkdir() - dataset = fetch_emnist_dataset(split="byclass", train=False) - mapping = load_emnist_mapping() + dataset = EmnistDataset(train=False) + dataset.load_or_generate_data() + mapping = EmnistMapper() for index in [5, 7, 9]: image, label = dataset[index] if len(image.shape) == 3: image = image.squeeze(0) image = image.numpy() - label = mapping[int(label)] + label = mapping(int(label)) print(index, label) write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml index 12a0a9d..bf2ed0a 100644 --- a/src/training/experiments/default_config_emnist.yml +++ b/src/training/experiments/default_config_emnist.yml @@ -66,4 +66,5 @@ callback_args: null verbosity: 1 # 0, 1, 2 resume_experiment: null +train: true validation_metric: val_accuracy diff --git a/src/training/experiments/embedding_experiment.yml b/src/training/experiments/embedding_experiment.yml new file mode 100644 index 0000000..e674c26 --- /dev/null +++ b/src/training/experiments/embedding_experiment.yml @@ -0,0 +1,58 @@ +experiment_group: Embedding Experiments +experiments: + - train_args: + batch_size: 256 + max_epochs: &max_epochs 8 + dataset: + type: EmnistDataset + args: + sample_to_balance: true + subsample_fraction: null + transform: null + target_transform: null + seed: 4711 + train_args: + num_workers: 8 + train_fraction: 0.85 + model: CharacterModel + metrics: [] + network: + type: ResidualNetwork + args: + in_channels: 1 + num_classes: 64 # Embedding + depths: [2,2] + block_sizes: [32, 64] + activation: selu + stn: false + criterion: + type: EmbeddingLoss + args: + margin: 0.2 + type_of_triplets: semihard + optimizer: + type: AdamW + args: + lr: 1.e-02 + betas: [0.9, 0.999] + eps: 1.e-08 + weight_decay: 5.e-4 + amsgrad: false + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: *max_epochs + callbacks: [Checkpoint, ProgressBar, WandbCallback] + callback_args: + Checkpoint: + monitor: val_loss + mode: min + ProgressBar: + epochs: *max_epochs + WandbCallback: + log_batch_frequency: 10 + verbosity: 1 # 0, 1, 2 + resume_experiment: null + train: true + test: true + test_metric: mean_average_precision_at_r diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml index f309990..ef97527 100644 --- a/src/training/experiments/line_ctc_experiment.yml +++ b/src/training/experiments/line_ctc_experiment.yml @@ -27,7 +27,7 @@ experiments: # stn: false backbone: ResidualNetwork backbone_args: - pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0920_010806/model/best.pt + pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0920_025816/model/best.pt freeze: false flatten: false input_size: 64 @@ -87,5 +87,6 @@ experiments: num_examples: 6 verbosity: 1 # 0, 1, 2 resume_experiment: null + train: true test: true test_metric: test_cer diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 8664a15..a073a87 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -95,5 +95,6 @@ experiments: use_transpose: true verbosity: 0 # 0, 1, 2 resume_experiment: null + train: true test: true test_metric: test_accuracy diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index cc882ad..9d45841 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -273,7 +273,8 @@ def run_experiment( ) # Train the model. - trainer.fit(model) + if experiment_config["train"]: + trainer.fit(model) # Run inference over test set. if experiment_config["test"]: diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py index 8c7b085..f81fc1f 100644 --- a/src/training/trainer/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -92,7 +92,7 @@ class CallbackList: def append(self, callback: Type[Callback]) -> None: """Append new callback to callback list.""" - self.callbacks.append(callback) + self._callbacks.append(callback) def on_fit_begin(self) -> None: """Called when fit begins.""" |