summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 11:47:24 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 11:47:24 +0200
commit1b3b8073a19f939d18a0bb85247eb0d99284f7cc (patch)
treee74e78230ebb179237c063fecf0b52458ce3aa3e /src/text_recognizer
parent6137f43c910946301279825e50759a9dd76c6131 (diff)
Bash scripts and some bug fixes.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py9
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py7
-rw-r--r--src/text_recognizer/datasets/util.py2
-rw-r--r--src/text_recognizer/models/base.py10
-rw-r--r--src/text_recognizer/tests/support/create_emnist_support_files.py13
5 files changed, 23 insertions, 18 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6268a01..beb5343 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -149,6 +149,7 @@ class EmnistLinesDataset(Dataset):
# Load emnist dataset.
emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+ emnist.load_or_generate_data()
samples_by_character = get_samples_by_character(
emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
@@ -306,17 +307,13 @@ def create_datasets(
num_test: int = 1000,
) -> None:
"""Creates a training an validation dataset of Emnist lines."""
- emnist_train = EmnistDataset(train=True, sample_to_balance=True)
- emnist_test = EmnistDataset(train=False, sample_to_balance=True)
- datasets = [emnist_train, emnist_test]
num_samples = [num_train, num_test]
- for num, train, dataset in zip(num_samples, [True, False], datasets):
+ for num, train in zip(num_samples, [True, False]):
emnist_lines = EmnistLinesDataset(
train=train,
- emnist=dataset,
max_length=max_length,
min_overlap=min_overlap,
max_overlap=max_overlap,
num_samples=num,
)
- emnist_lines._load_or_generate_data()
+ emnist_lines.load_or_generate_data()
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index 4b34bd1..c1e8fe2 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None:
@click.option(
"--subsample_fraction",
type=float,
- default=0.0,
+ default=None,
help="The subsampling factor of the dataset.",
)
def main(subsample_fraction: float) -> None:
"""Load dataset and print info."""
+ logger.info("Creating train set...")
+ dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+ logger.info("Creating test set...")
dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
dataset.load_or_generate_data()
print(dataset)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 73968a1..125f05a 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -26,7 +26,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
mapping = [(i, str(label)) for i, label in enumerate(labels)]
essentials = {
"mapping": mapping,
- "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
+ "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
}
logger.info("Saving emnist essentials...")
with open(ESSENTIALS_FILENAME, "w") as f:
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..e89b670 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -356,7 +356,8 @@ class Model(ABC):
state["optimizer_state"] = self._optimizer.state_dict()
if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler.state_dict()
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
if self._swa_network is not None:
state["swa_network"] = self._swa_network.state_dict()
@@ -383,8 +384,11 @@ class Model(ABC):
if self._lr_scheduler is not None:
# Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
- if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
if self._swa_network is not None:
self._swa_network.load_state_dict(checkpoint["swa_network"])
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
index 5dd1a81..c04860d 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -2,10 +2,8 @@
from pathlib import Path
import shutil
-from text_recognizer.datasets.emnist_dataset import (
- fetch_emnist_dataset,
- load_emnist_mapping,
-)
+from text_recognizer.datasets.emnist_dataset import EmnistDataset
+from text_recognizer.datasets.util import EmnistMapper
from text_recognizer.util import write_image
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
@@ -16,15 +14,16 @@ def create_emnist_support_files() -> None:
shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
SUPPORT_DIRNAME.mkdir()
- dataset = fetch_emnist_dataset(split="byclass", train=False)
- mapping = load_emnist_mapping()
+ dataset = EmnistDataset(train=False)
+ dataset.load_or_generate_data()
+ mapping = EmnistMapper()
for index in [5, 7, 9]:
image, label = dataset[index]
if len(image.shape) == 3:
image = image.squeeze(0)
image = image.numpy()
- label = mapping[int(label)]
+ label = mapping(int(label))
print(index, label)
write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))