diff options
24 files changed, 108 insertions, 40 deletions
diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..5c74fa2 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,65 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + + +# Assume Python 3.8 +target-version = "py311" + +[lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +extend-select = ["I"] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index b5db075..4a14521 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -11,10 +11,10 @@ import numpy as np import toml from loguru import logger as log +import text_recognizer.metadata.emnist as metadata from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.utils.download_utils import download_dataset -import text_recognizer.metadata.emnist as metadata class EMNIST(BaseDataModule): diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 4ae4787..aaec93f 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -9,13 +9,13 @@ import torch from loguru import logger as log from torch import Tensor +import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.transforms.line import LineStem from text_recognizer.data.utils.sentence_generator import SentenceGenerator -import text_recognizer.metadata.emnist_lines as metadata class EMNISTLines(BaseDataModule): diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 8a31205..e1f6c21 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -4,18 +4,18 @@ Which encompasses both paragraphs and lines, with associated utilities. """ import os -from pathlib import Path -from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile +from pathlib import Path +from typing import Any, Dict, List +import toml from boltons.cacheutils import cachedproperty from loguru import logger as log -import toml +import text_recognizer.metadata.iam as metadata from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.utils.download_utils import download_dataset -import text_recognizer.metadata.iam as metadata class IAM(BaseDataModule): diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 3f3500f..3f730ab 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -3,13 +3,13 @@ from typing import Callable, Optional from torch.utils.data import ConcatDataset +import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs -from text_recognizer.data.transforms.pad import Pad from text_recognizer.data.tokenizer import Tokenizer +from text_recognizer.data.transforms.pad import Pad from text_recognizer.data.transforms.paragraph import ParagraphStem -import text_recognizer.metadata.iam_paragraphs as metadata class IAMExtendedParagraphs(BaseDataModule): diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 800e0d7..443386c 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -12,6 +12,7 @@ from loguru import logger as log from PIL import Image, ImageFile, ImageOps from torch import Tensor +import text_recognizer.metadata.iam_lines as metadata from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( BaseDataset, @@ -22,7 +23,6 @@ from text_recognizer.data.iam import IAM from text_recognizer.data.tokenizer import Tokenizer from text_recognizer.data.transforms.line import IamLinesStem from text_recognizer.data.utils import image_utils -import text_recognizer.metadata.iam_lines as metadata ImageFile.LOAD_TRUNCATED_IMAGES = True diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 0d53d6b..247b54d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -9,6 +9,7 @@ from loguru import logger as log from PIL import Image, ImageOps from tqdm import tqdm +import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( BaseDataset, @@ -16,10 +17,9 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM -from text_recognizer.data.transforms.pad import Pad from text_recognizer.data.tokenizer import Tokenizer +from text_recognizer.data.transforms.pad import Pad from text_recognizer.data.transforms.paragraph import ParagraphStem -import text_recognizer.metadata.iam_paragraphs as metadata class IAMParagraphs(BaseDataModule): diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 5743bcc..45d7904 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -6,23 +6,23 @@ import numpy as np from loguru import logger as log from PIL import Image +import text_recognizer.metadata.iam_synthetic_paragraphs as metadata from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.iam import IAM -from text_recognizer.data.iam_paragraphs import ( - IAMParagraphs, - get_dataset_properties, - resize_image, -) from text_recognizer.data.iam_lines import ( line_crops_and_labels, load_line_crops_and_labels, save_images_and_labels, ) +from text_recognizer.data.iam_paragraphs import ( + IAMParagraphs, + get_dataset_properties, + resize_image, +) from text_recognizer.data.tokenizer import Tokenizer -from text_recognizer.data.transforms.paragraph import ParagraphStem from text_recognizer.data.transforms.pad import Pad -import text_recognizer.metadata.iam_synthetic_paragraphs as metadata +from text_recognizer.data.transforms.paragraph import ParagraphStem class IAMSyntheticParagraphs(IAMParagraphs): diff --git a/text_recognizer/data/transforms/image.py b/text_recognizer/data/transforms/image.py index f04b3a0..05c9d94 100644 --- a/text_recognizer/data/transforms/image.py +++ b/text_recognizer/data/transforms/image.py @@ -1,7 +1,7 @@ -from PIL import Image import torch -from torch import Tensor import torchvision.transforms as T +from PIL import Image +from torch import Tensor class ImageStem: diff --git a/text_recognizer/data/transforms/line.py b/text_recognizer/data/transforms/line.py index 6c38213..e4473eb 100644 --- a/text_recognizer/data/transforms/line.py +++ b/text_recognizer/data/transforms/line.py @@ -1,8 +1,8 @@ import random from typing import Any, Dict -from PIL import Image import torchvision.transforms as T +from PIL import Image import text_recognizer.metadata.iam_lines as metadata from text_recognizer.data.transforms.image import ImageStem diff --git a/text_recognizer/data/transforms/paragraph.py b/text_recognizer/data/transforms/paragraph.py index b364f91..639bb59 100644 --- a/text_recognizer/data/transforms/paragraph.py +++ b/text_recognizer/data/transforms/paragraph.py @@ -4,7 +4,6 @@ import torchvision.transforms as T import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.data.transforms.image import ImageStem - IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE diff --git a/text_recognizer/data/utils/sentence_generator.py b/text_recognizer/data/utils/sentence_generator.py index c40373d..8ea345a 100644 --- a/text_recognizer/data/utils/sentence_generator.py +++ b/text_recognizer/data/utils/sentence_generator.py @@ -5,8 +5,8 @@ import string from typing import Optional import nltk -from nltk.corpus.reader.util import ConcatenatedCorpusView import numpy as np +from nltk.corpus.reader.util import ConcatenatedCorpusView import text_recognizer.metadata.shared as metadata diff --git a/text_recognizer/metadata/iam_paragraphs.py b/text_recognizer/metadata/iam_paragraphs.py index 7bb909a..57aa2dc 100644 --- a/text_recognizer/metadata/iam_paragraphs.py +++ b/text_recognizer/metadata/iam_paragraphs.py @@ -1,7 +1,6 @@ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared - PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py index 9a751bf..c662583 100644 --- a/text_recognizer/model/base.py +++ b/text_recognizer/model/base.py @@ -2,11 +2,11 @@ from typing import Any, Dict, Optional, Tuple, Type import hydra +import pytorch_lightning as L import torch from loguru import logger as log from omegaconf import DictConfig -import pytorch_lightning as L -from torch import nn, Tensor +from torch import Tensor, nn from text_recognizer.data.tokenizer import Tokenizer diff --git a/text_recognizer/model/transformer.py b/text_recognizer/model/transformer.py index 5842bdb..783e134 100644 --- a/text_recognizer/model/transformer.py +++ b/text_recognizer/model/transformer.py @@ -3,11 +3,12 @@ from typing import Callable, Optional, Tuple, Type import torch from omegaconf import DictConfig -from torch import nn, Tensor +from torch import Tensor, nn from torchmetrics import CharErrorRate, WordErrorRate -from text_recognizer.decoder.greedy_decoder import GreedyDecoder from text_recognizer.data.tokenizer import Tokenizer +from text_recognizer.decoder.greedy_decoder import GreedyDecoder + from .base import LitBase diff --git a/text_recognizer/network/convnext/convnext.py b/text_recognizer/network/convnext/convnext.py index 8eea9df..9d2b3ec 100644 --- a/text_recognizer/network/convnext/convnext.py +++ b/text_recognizer/network/convnext/convnext.py @@ -4,9 +4,9 @@ from typing import Optional, Sequence import torch from torch import Tensor, nn -from .transformer import Transformer from .downsample import Downsample from .norm import LayerNorm +from .transformer import Transformer class GRN(nn.Module): diff --git a/text_recognizer/network/transformer/embedding/l2_norm.py b/text_recognizer/network/transformer/embedding/l2_norm.py index 0e48bca..f5ec4ba 100644 --- a/text_recognizer/network/transformer/embedding/l2_norm.py +++ b/text_recognizer/network/transformer/embedding/l2_norm.py @@ -1,5 +1,5 @@ -from einops import rearrange import torch.nn.functional as F +from einops import rearrange from torch import Tensor diff --git a/text_recognizer/network/transformer/embedding/rotary.py b/text_recognizer/network/transformer/embedding/rotary.py index 2254f81..7e88264 100644 --- a/text_recognizer/network/transformer/embedding/rotary.py +++ b/text_recognizer/network/transformer/embedding/rotary.py @@ -1,6 +1,6 @@ import torch -from torch import nn, einsum from einops import rearrange +from torch import einsum, nn class RotaryEmbedding(nn.Module): diff --git a/text_recognizer/network/transformer/embedding/token.py b/text_recognizer/network/transformer/embedding/token.py index 838f514..bf5149e 100644 --- a/text_recognizer/network/transformer/embedding/token.py +++ b/text_recognizer/network/transformer/embedding/token.py @@ -1,4 +1,4 @@ -from torch import nn, Tensor +from torch import Tensor, nn from .l2_norm import l2_norm diff --git a/training/artifacts.py b/training/artifacts.py index 2cb8bda..f7af085 100644 --- a/training/artifacts.py +++ b/training/artifacts.py @@ -1,14 +1,15 @@ """Fetches model artifacts from wandb.""" -from datetime import datetime -from pathlib import Path import shutil import sys +from datetime import datetime +from pathlib import Path from typing import Optional import click from loguru import logger as log -from training import metadata + import wandb +from training import metadata from wandb.apis.public import Run diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py index 6adbebe..0841000 100644 --- a/training/callbacks/wandb.py +++ b/training/callbacks/wandb.py @@ -2,11 +2,12 @@ from pathlib import Path from typing import Tuple -import wandb -from torch import Tensor from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only +from torch import Tensor + +import wandb def get_wandb_logger(trainer: Trainer) -> WandbLogger: diff --git a/training/main.py b/training/main.py index c36e397..17b961c 100644 --- a/training/main.py +++ b/training/main.py @@ -1,6 +1,7 @@ """Loads config with hydra and runs experiment.""" import hydra from omegaconf import DictConfig + from training.metadata import TRAINING_DIR diff --git a/training/run.py b/training/run.py index b8f2178..2089a86 100644 --- a/training/run.py +++ b/training/run.py @@ -2,19 +2,19 @@ from typing import Callable, List, Optional, Type import hydra +import utils from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, LightningDataModule, LightningModule, - seed_everything, Trainer, + seed_everything, ) from pytorch_lightning.loggers import Logger from torch import nn from torchinfo import summary -import utils def run(config: DictConfig) -> Optional[float]: diff --git a/training/utils.py b/training/utils.py index d1801a7..f0a0d4d 100644 --- a/training/utils.py +++ b/training/utils.py @@ -1,6 +1,6 @@ """Util functions for training with hydra and pytorch lightning.""" -from typing import List, Type import warnings +from typing import List, Type import hydra from loguru import logger as log @@ -14,6 +14,7 @@ from pytorch_lightning.loggers import Logger from pytorch_lightning.loggers.wandb import WandbLogger from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm + import wandb |