summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/data/emnist.py14
-rw-r--r--text_recognizer/data/emnist_lines.py10
-rw-r--r--text_recognizer/data/iam.py7
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py14
-rw-r--r--text_recognizer/data/iam_lines.py3
-rw-r--r--text_recognizer/data/iam_paragraphs.py5
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py12
-rw-r--r--text_recognizer/data/mappings/emnist.py2
-rw-r--r--text_recognizer/data/transforms/load_transform.py7
-rw-r--r--text_recognizer/data/transforms/pad.py2
-rw-r--r--text_recognizer/data/utils/image_utils.py2
-rw-r--r--text_recognizer/data/utils/sentence_generator.py2
-rw-r--r--text_recognizer/models/base.py5
-rw-r--r--text_recognizer/models/transformer.py4
-rw-r--r--training/callbacks/wandb_callbacks.py12
15 files changed, 49 insertions, 52 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index e551543..72cc80a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,25 +1,21 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
import json
import os
-from pathlib import Path
import shutil
-from typing import Sequence, Tuple
import zipfile
+from pathlib import Path
+from typing import Optional, Sequence, Tuple
import h5py
-from loguru import logger as log
import numpy as np
import toml
+from loguru import logger as log
-from text_recognizer.data.base_data_module import (
- BaseDataModule,
- load_and_print_info,
-)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.download_utils import download_dataset
-
SEED = 4711
NUM_SPECIAL_TOKENS = 4
SAMPLE_TO_BALANCE = True
@@ -55,7 +51,7 @@ class EMNIST(BaseDataModule):
if not PROCESSED_DATA_FILENAME.exists():
download_and_process_emnist()
- def setup(self, stage: str = None) -> None:
+ def setup(self, stage: Optional[str] = None) -> None:
"""Loads the dataset specified by the stage."""
if stage == "fit" or stage is None:
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index b33c9e1..c36132e 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -4,23 +4,19 @@ from pathlib import Path
from typing import Callable, DefaultDict, List, Optional, Tuple, Type
import h5py
-from loguru import logger as log
import numpy as np
import torch
-from torch import Tensor
import torchvision.transforms as T
+from loguru import logger as log
+from torch import Tensor
-from text_recognizer.data.base_data_module import (
- BaseDataModule,
- load_and_print_info,
-)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
-
DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
ESSENTIALS_FILENAME = (
Path(__file__).parents[0].resolve() / "mappings" / "emnist_lines_essentials.json"
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index a4d1d21..e3baf88 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -4,19 +4,18 @@ Which encompasses both paragraphs and lines, with associated utilities.
"""
import os
-from pathlib import Path
-from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
+from pathlib import Path
+from typing import Any, Dict, List
+import toml
from boltons.cacheutils import cachedproperty
from loguru import logger as log
-import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.utils.download_utils import download_dataset
-
RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam"
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 90df5f8..3ec8221 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,11 +1,12 @@
"""IAM original and sythetic dataset class."""
from typing import Callable, Optional
+
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
-from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
@@ -87,10 +88,15 @@ class IAMExtendedParagraphs(BaseDataModule):
x = x[0] if isinstance(x, list) else x
xt = xt[0] if isinstance(xt, list) else xt
data = (
- f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
- f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ "Train/val/test sizes: "
+ f"{len(self.data_train)}, "
+ f"{len(self.data_val)}, "
+ f"{len(self.data_test)}\n"
+ "Train Batch x stats: "
+ f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
- f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+ f"Test Batch x stats: "
+ f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
)
return basic + data
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 4899a48..a55ff1c 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,8 +7,8 @@ import json
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple, Type
-from loguru import logger as log
import numpy as np
+from loguru import logger as log
from PIL import Image, ImageFile, ImageOps
from torch import Tensor
@@ -23,7 +23,6 @@ from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils import image_utils
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
SEED = 4711
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 3bf28ff..c7d5229 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,10 +3,10 @@ import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple
-from loguru import logger as log
import numpy as np
-from PIL import Image, ImageOps
import torchvision.transforms as T
+from loguru import logger as log
+from PIL import Image, ImageOps
from tqdm import tqdm
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -19,7 +19,6 @@ from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
NEW_LINE_TOKEN = "\n"
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index d51e010..5e66499 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,15 +2,12 @@
import random
from typing import Any, Callable, List, Optional, Sequence, Tuple
-from loguru import logger as log
import numpy as np
+from loguru import logger as log
from PIL import Image
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.base_dataset import (
- BaseDataset,
- convert_strings_to_labels,
-)
+from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -18,16 +15,15 @@ from text_recognizer.data.iam_lines import (
save_images_and_labels,
)
from text_recognizer.data.iam_paragraphs import (
- get_dataset_properties,
- IAMParagraphs,
IMAGE_SCALE_FACTOR,
NEW_LINE_TOKEN,
+ IAMParagraphs,
+ get_dataset_properties,
resize_image,
)
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
PROCESSED_DATA_DIRNAME = (
BaseDataModule.data_dirname() / "processed" / "iam_synthetic_paragraphs"
)
diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py
index 03465a1..331976e 100644
--- a/text_recognizer/data/mappings/emnist.py
+++ b/text_recognizer/data/mappings/emnist.py
@@ -1,7 +1,7 @@
"""Emnist mapping."""
import json
from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Union, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py
index cf590c1..e8c57bc 100644
--- a/text_recognizer/data/transforms/load_transform.py
+++ b/text_recognizer/data/transforms/load_transform.py
@@ -2,11 +2,10 @@
from pathlib import Path
from typing import Callable
-from loguru import logger as log
-from omegaconf import OmegaConf
-from omegaconf import DictConfig
-from hydra.utils import instantiate
import torchvision.transforms as T
+from hydra.utils import instantiate
+from loguru import logger as log
+from omegaconf import DictConfig, OmegaConf
TRANSFORM_DIRNAME = (
Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
diff --git a/text_recognizer/data/transforms/pad.py b/text_recognizer/data/transforms/pad.py
index 1da4534..df1d83f 100644
--- a/text_recognizer/data/transforms/pad.py
+++ b/text_recognizer/data/transforms/pad.py
@@ -1,8 +1,8 @@
"""Pad targets to equal length."""
import torch
-from torch import Tensor
import torch.functional as F
+from torch import Tensor
class Pad:
diff --git a/text_recognizer/data/utils/image_utils.py b/text_recognizer/data/utils/image_utils.py
index d374fc3..3f4901c 100644
--- a/text_recognizer/data/utils/image_utils.py
+++ b/text_recognizer/data/utils/image_utils.py
@@ -4,8 +4,8 @@ from io import BytesIO
from pathlib import Path
from typing import Any, Union
-from PIL import Image
import smart_open
+from PIL import Image
def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image:
diff --git a/text_recognizer/data/utils/sentence_generator.py b/text_recognizer/data/utils/sentence_generator.py
index 8567e6d..c98d0da 100644
--- a/text_recognizer/data/utils/sentence_generator.py
+++ b/text_recognizer/data/utils/sentence_generator.py
@@ -5,8 +5,8 @@ import string
from typing import Optional
import nltk
-from nltk.corpus.reader.util import ConcatenatedCorpusView
import numpy as np
+from nltk.corpus.reader.util import ConcatenatedCorpusView
from text_recognizer.data.base_data_module import BaseDataModule
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 90752e9..f917635 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -2,12 +2,11 @@
from typing import Any, Dict, Optional, Tuple, Type
import hydra
+import torch
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
-import torch
-from torch import nn
-from torch import Tensor
+from torch import Tensor, nn
from torchmetrics import Accuracy
from text_recognizer.data.mappings import EmnistMapping
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f44241d..dcec756 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,9 +1,9 @@
"""Lightning model for base Transformers."""
from typing import Optional, Tuple, Type
-from omegaconf import DictConfig
import torch
-from torch import nn, Tensor
+from omegaconf import DictConfig
+from torch import Tensor, nn
from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.models.base import LitBase
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index aa72480..dc59f19 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
- def __init__(self, log: str = "gradients", log_freq: int = 100) -> None:
+ def __init__(
+ self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False
+ ) -> None:
self.log = log
self.log_freq = log_freq
+ self.log_graph = log_graph
@rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
- logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+ logger.watch(
+ model=trainer.model,
+ log=self.log,
+ log_freq=self.log_freq,
+ log_graph=self.log_graph,
+ )
class UploadConfigAsArtifact(Callback):