diff options
-rw-r--r-- | text_recognizer/data/emnist.py | 14 | ||||
-rw-r--r-- | text_recognizer/data/emnist_lines.py | 10 | ||||
-rw-r--r-- | text_recognizer/data/iam.py | 7 | ||||
-rw-r--r-- | text_recognizer/data/iam_extended_paragraphs.py | 14 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 3 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 5 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 12 | ||||
-rw-r--r-- | text_recognizer/data/mappings/emnist.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/transforms/load_transform.py | 7 | ||||
-rw-r--r-- | text_recognizer/data/transforms/pad.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/utils/image_utils.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/utils/sentence_generator.py | 2 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 5 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 4 | ||||
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 12 |
15 files changed, 49 insertions, 52 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index e551543..72cc80a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -1,25 +1,21 @@ """EMNIST dataset: downloads it from FSDL aws url if not present.""" import json import os -from pathlib import Path import shutil -from typing import Sequence, Tuple import zipfile +from pathlib import Path +from typing import Optional, Sequence, Tuple import h5py -from loguru import logger as log import numpy as np import toml +from loguru import logger as log -from text_recognizer.data.base_data_module import ( - BaseDataModule, - load_and_print_info, -) +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.download_utils import download_dataset - SEED = 4711 NUM_SPECIAL_TOKENS = 4 SAMPLE_TO_BALANCE = True @@ -55,7 +51,7 @@ class EMNIST(BaseDataModule): if not PROCESSED_DATA_FILENAME.exists(): download_and_process_emnist() - def setup(self, stage: str = None) -> None: + def setup(self, stage: Optional[str] = None) -> None: """Loads the dataset specified by the stage.""" if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index b33c9e1..c36132e 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -4,23 +4,19 @@ from pathlib import Path from typing import Callable, DefaultDict, List, Optional, Tuple, Type import h5py -from loguru import logger as log import numpy as np import torch -from torch import Tensor import torchvision.transforms as T +from loguru import logger as log +from torch import Tensor -from text_recognizer.data.base_data_module import ( - BaseDataModule, - load_and_print_info, -) +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.sentence_generator import SentenceGenerator - DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" ESSENTIALS_FILENAME = ( Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json" diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index a4d1d21..e3baf88 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -4,19 +4,18 @@ Which encompasses both paragraphs and lines, with associated utilities. """ import os -from pathlib import Path -from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile +from pathlib import Path +from typing import Any, Dict, List +import toml from boltons.cacheutils import cachedproperty from loguru import logger as log -import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.utils.download_utils import download_dataset - RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam" diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 90df5f8..3ec8221 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -1,11 +1,12 @@ """IAM original and sythetic dataset class.""" from typing import Callable, Optional + from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs -from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file @@ -87,10 +88,15 @@ class IAMExtendedParagraphs(BaseDataModule): x = x[0] if isinstance(x, list) else x xt = xt[0] if isinstance(xt, list) else xt data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" - f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + "Train/val/test sizes: " + f"{len(self.data_train)}, " + f"{len(self.data_val)}, " + f"{len(self.data_test)}\n" + "Train Batch x stats: " + f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" - f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" + f"Test Batch x stats: " + f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 4899a48..a55ff1c 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -7,8 +7,8 @@ import json from pathlib import Path from typing import Callable, List, Optional, Sequence, Tuple, Type -from loguru import logger as log import numpy as np +from loguru import logger as log from PIL import Image, ImageFile, ImageOps from torch import Tensor @@ -23,7 +23,6 @@ from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils import image_utils - ImageFile.LOAD_TRUNCATED_IMAGES = True SEED = 4711 diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 3bf28ff..c7d5229 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -3,10 +3,10 @@ import json from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Tuple -from loguru import logger as log import numpy as np -from PIL import Image, ImageOps import torchvision.transforms as T +from loguru import logger as log +from PIL import Image, ImageOps from tqdm import tqdm from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -19,7 +19,6 @@ from text_recognizer.data.iam import IAM from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file - PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index d51e010..5e66499 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,15 +2,12 @@ import random from typing import Any, Callable, List, Optional, Sequence, Tuple -from loguru import logger as log import numpy as np +from loguru import logger as log from PIL import Image from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.base_dataset import ( - BaseDataset, - convert_strings_to_labels, -) +from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( line_crops_and_labels, @@ -18,16 +15,15 @@ from text_recognizer.data.iam_lines import ( save_images_and_labels, ) from text_recognizer.data.iam_paragraphs import ( - get_dataset_properties, - IAMParagraphs, IMAGE_SCALE_FACTOR, NEW_LINE_TOKEN, + IAMParagraphs, + get_dataset_properties, resize_image, ) from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file - PROCESSED_DATA_DIRNAME = ( BaseDataModule.data_dirname() / "processed" / "iam_synthetic_paragraphs" ) diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py index 03465a1..331976e 100644 --- a/text_recognizer/data/mappings/emnist.py +++ b/text_recognizer/data/mappings/emnist.py @@ -1,7 +1,7 @@ """Emnist mapping.""" import json from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py index cf590c1..e8c57bc 100644 --- a/text_recognizer/data/transforms/load_transform.py +++ b/text_recognizer/data/transforms/load_transform.py @@ -2,11 +2,10 @@ from pathlib import Path from typing import Callable -from loguru import logger as log -from omegaconf import OmegaConf -from omegaconf import DictConfig -from hydra.utils import instantiate import torchvision.transforms as T +from hydra.utils import instantiate +from loguru import logger as log +from omegaconf import DictConfig, OmegaConf TRANSFORM_DIRNAME = ( Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule" diff --git a/text_recognizer/data/transforms/pad.py b/text_recognizer/data/transforms/pad.py index 1da4534..df1d83f 100644 --- a/text_recognizer/data/transforms/pad.py +++ b/text_recognizer/data/transforms/pad.py @@ -1,8 +1,8 @@ """Pad targets to equal length.""" import torch -from torch import Tensor import torch.functional as F +from torch import Tensor class Pad: diff --git a/text_recognizer/data/utils/image_utils.py b/text_recognizer/data/utils/image_utils.py index d374fc3..3f4901c 100644 --- a/text_recognizer/data/utils/image_utils.py +++ b/text_recognizer/data/utils/image_utils.py @@ -4,8 +4,8 @@ from io import BytesIO from pathlib import Path from typing import Any, Union -from PIL import Image import smart_open +from PIL import Image def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image: diff --git a/text_recognizer/data/utils/sentence_generator.py b/text_recognizer/data/utils/sentence_generator.py index 8567e6d..c98d0da 100644 --- a/text_recognizer/data/utils/sentence_generator.py +++ b/text_recognizer/data/utils/sentence_generator.py @@ -5,8 +5,8 @@ import string from typing import Optional import nltk -from nltk.corpus.reader.util import ConcatenatedCorpusView import numpy as np +from nltk.corpus.reader.util import ConcatenatedCorpusView from text_recognizer.data.base_data_module import BaseDataModule diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 90752e9..f917635 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -2,12 +2,11 @@ from typing import Any, Dict, Optional, Tuple, Type import hydra +import torch from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import LightningModule -import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torchmetrics import Accuracy from text_recognizer.data.mappings import EmnistMapping diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index f44241d..dcec756 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,9 +1,9 @@ """Lightning model for base Transformers.""" from typing import Optional, Tuple, Type -from omegaconf import DictConfig import torch -from torch import nn, Tensor +from omegaconf import DictConfig +from torch import Tensor, nn from text_recognizer.data.mappings import EmnistMapping from text_recognizer.models.base import LitBase diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index aa72480..dc59f19 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" - def __init__(self, log: str = "gradients", log_freq: int = 100) -> None: + def __init__( + self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False + ) -> None: self.log = log self.log_freq = log_freq + self.log_graph = log_graph @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Watches model weights with wandb.""" logger = get_wandb_logger(trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + logger.watch( + model=trainer.model, + log=self.log, + log_freq=self.log_freq, + log_graph=self.log_graph, + ) class UploadConfigAsArtifact(Callback): |