summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-15 21:47:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-15 21:47:54 +0200
commit9ce21f569ecac03f15f2ad669fde3dd4a512f8cc (patch)
treee6f5bbf4cfe758788fd6ad3679b714d4ecfad568
parenta93dcc5b9c8160a441c5b5f99f2f59264778ef91 (diff)
Format
-rw-r--r--ruff.toml65
-rw-r--r--text_recognizer/data/emnist.py2
-rw-r--r--text_recognizer/data/emnist_lines.py2
-rw-r--r--text_recognizer/data/iam.py8
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py4
-rw-r--r--text_recognizer/data/iam_lines.py2
-rw-r--r--text_recognizer/data/iam_paragraphs.py4
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py14
-rw-r--r--text_recognizer/data/transforms/image.py4
-rw-r--r--text_recognizer/data/transforms/line.py2
-rw-r--r--text_recognizer/data/transforms/paragraph.py1
-rw-r--r--text_recognizer/data/utils/sentence_generator.py2
-rw-r--r--text_recognizer/metadata/iam_paragraphs.py1
-rw-r--r--text_recognizer/model/base.py4
-rw-r--r--text_recognizer/model/transformer.py5
-rw-r--r--text_recognizer/network/convnext/convnext.py2
-rw-r--r--text_recognizer/network/transformer/embedding/l2_norm.py2
-rw-r--r--text_recognizer/network/transformer/embedding/rotary.py2
-rw-r--r--text_recognizer/network/transformer/embedding/token.py2
-rw-r--r--training/artifacts.py7
-rw-r--r--training/callbacks/wandb.py5
-rw-r--r--training/main.py1
-rw-r--r--training/run.py4
-rw-r--r--training/utils.py3
24 files changed, 108 insertions, 40 deletions
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..5c74fa2
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,65 @@
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".ipynb_checkpoints",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pyenv",
+ ".pytest_cache",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ ".vscode",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "site-packages",
+ "venv",
+]
+
+# Same as Black.
+line-length = 88
+indent-width = 4
+
+
+# Assume Python 3.8
+target-version = "py311"
+
+[lint]
+# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
+select = ["E4", "E7", "E9", "F"]
+ignore = []
+
+extend-select = ["I"]
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index b5db075..4a14521 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -11,10 +11,10 @@ import numpy as np
import toml
from loguru import logger as log
+import text_recognizer.metadata.emnist as metadata
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, split_dataset
from text_recognizer.data.utils.download_utils import download_dataset
-import text_recognizer.metadata.emnist as metadata
class EMNIST(BaseDataModule):
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 4ae4787..aaec93f 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -9,13 +9,13 @@ import torch
from loguru import logger as log
from torch import Tensor
+import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.emnist import EMNIST
from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.transforms.line import LineStem
from text_recognizer.data.utils.sentence_generator import SentenceGenerator
-import text_recognizer.metadata.emnist_lines as metadata
class EMNISTLines(BaseDataModule):
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 8a31205..e1f6c21 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -4,18 +4,18 @@ Which encompasses both paragraphs and lines, with associated utilities.
"""
import os
-from pathlib import Path
-from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
+from pathlib import Path
+from typing import Any, Dict, List
+import toml
from boltons.cacheutils import cachedproperty
from loguru import logger as log
-import toml
+import text_recognizer.metadata.iam as metadata
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.utils.download_utils import download_dataset
-import text_recognizer.metadata.iam as metadata
class IAM(BaseDataModule):
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 3f3500f..3f730ab 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -3,13 +3,13 @@ from typing import Callable, Optional
from torch.utils.data import ConcatDataset
+import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
-from text_recognizer.data.transforms.pad import Pad
from text_recognizer.data.tokenizer import Tokenizer
+from text_recognizer.data.transforms.pad import Pad
from text_recognizer.data.transforms.paragraph import ParagraphStem
-import text_recognizer.metadata.iam_paragraphs as metadata
class IAMExtendedParagraphs(BaseDataModule):
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 800e0d7..443386c 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -12,6 +12,7 @@ from loguru import logger as log
from PIL import Image, ImageFile, ImageOps
from torch import Tensor
+import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
@@ -22,7 +23,6 @@ from text_recognizer.data.iam import IAM
from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.data.transforms.line import IamLinesStem
from text_recognizer.data.utils import image_utils
-import text_recognizer.metadata.iam_lines as metadata
ImageFile.LOAD_TRUNCATED_IMAGES = True
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 0d53d6b..247b54d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -9,6 +9,7 @@ from loguru import logger as log
from PIL import Image, ImageOps
from tqdm import tqdm
+import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
BaseDataset,
@@ -16,10 +17,9 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.iam import IAM
-from text_recognizer.data.transforms.pad import Pad
from text_recognizer.data.tokenizer import Tokenizer
+from text_recognizer.data.transforms.pad import Pad
from text_recognizer.data.transforms.paragraph import ParagraphStem
-import text_recognizer.metadata.iam_paragraphs as metadata
class IAMParagraphs(BaseDataModule):
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 5743bcc..45d7904 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -6,23 +6,23 @@ import numpy as np
from loguru import logger as log
from PIL import Image
+import text_recognizer.metadata.iam_synthetic_paragraphs as metadata
from text_recognizer.data.base_data_module import load_and_print_info
from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
from text_recognizer.data.iam import IAM
-from text_recognizer.data.iam_paragraphs import (
- IAMParagraphs,
- get_dataset_properties,
- resize_image,
-)
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
load_line_crops_and_labels,
save_images_and_labels,
)
+from text_recognizer.data.iam_paragraphs import (
+ IAMParagraphs,
+ get_dataset_properties,
+ resize_image,
+)
from text_recognizer.data.tokenizer import Tokenizer
-from text_recognizer.data.transforms.paragraph import ParagraphStem
from text_recognizer.data.transforms.pad import Pad
-import text_recognizer.metadata.iam_synthetic_paragraphs as metadata
+from text_recognizer.data.transforms.paragraph import ParagraphStem
class IAMSyntheticParagraphs(IAMParagraphs):
diff --git a/text_recognizer/data/transforms/image.py b/text_recognizer/data/transforms/image.py
index f04b3a0..05c9d94 100644
--- a/text_recognizer/data/transforms/image.py
+++ b/text_recognizer/data/transforms/image.py
@@ -1,7 +1,7 @@
-from PIL import Image
import torch
-from torch import Tensor
import torchvision.transforms as T
+from PIL import Image
+from torch import Tensor
class ImageStem:
diff --git a/text_recognizer/data/transforms/line.py b/text_recognizer/data/transforms/line.py
index 6c38213..e4473eb 100644
--- a/text_recognizer/data/transforms/line.py
+++ b/text_recognizer/data/transforms/line.py
@@ -1,8 +1,8 @@
import random
from typing import Any, Dict
-from PIL import Image
import torchvision.transforms as T
+from PIL import Image
import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.data.transforms.image import ImageStem
diff --git a/text_recognizer/data/transforms/paragraph.py b/text_recognizer/data/transforms/paragraph.py
index b364f91..639bb59 100644
--- a/text_recognizer/data/transforms/paragraph.py
+++ b/text_recognizer/data/transforms/paragraph.py
@@ -4,7 +4,6 @@ import torchvision.transforms as T
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.data.transforms.image import ImageStem
-
IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE
diff --git a/text_recognizer/data/utils/sentence_generator.py b/text_recognizer/data/utils/sentence_generator.py
index c40373d..8ea345a 100644
--- a/text_recognizer/data/utils/sentence_generator.py
+++ b/text_recognizer/data/utils/sentence_generator.py
@@ -5,8 +5,8 @@ import string
from typing import Optional
import nltk
-from nltk.corpus.reader.util import ConcatenatedCorpusView
import numpy as np
+from nltk.corpus.reader.util import ConcatenatedCorpusView
import text_recognizer.metadata.shared as metadata
diff --git a/text_recognizer/metadata/iam_paragraphs.py b/text_recognizer/metadata/iam_paragraphs.py
index 7bb909a..57aa2dc 100644
--- a/text_recognizer/metadata/iam_paragraphs.py
+++ b/text_recognizer/metadata/iam_paragraphs.py
@@ -1,7 +1,6 @@
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
-
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs"
NEW_LINE_TOKEN = "\n"
diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py
index 9a751bf..c662583 100644
--- a/text_recognizer/model/base.py
+++ b/text_recognizer/model/base.py
@@ -2,11 +2,11 @@
from typing import Any, Dict, Optional, Tuple, Type
import hydra
+import pytorch_lightning as L
import torch
from loguru import logger as log
from omegaconf import DictConfig
-import pytorch_lightning as L
-from torch import nn, Tensor
+from torch import Tensor, nn
from text_recognizer.data.tokenizer import Tokenizer
diff --git a/text_recognizer/model/transformer.py b/text_recognizer/model/transformer.py
index 5842bdb..783e134 100644
--- a/text_recognizer/model/transformer.py
+++ b/text_recognizer/model/transformer.py
@@ -3,11 +3,12 @@ from typing import Callable, Optional, Tuple, Type
import torch
from omegaconf import DictConfig
-from torch import nn, Tensor
+from torch import Tensor, nn
from torchmetrics import CharErrorRate, WordErrorRate
-from text_recognizer.decoder.greedy_decoder import GreedyDecoder
from text_recognizer.data.tokenizer import Tokenizer
+from text_recognizer.decoder.greedy_decoder import GreedyDecoder
+
from .base import LitBase
diff --git a/text_recognizer/network/convnext/convnext.py b/text_recognizer/network/convnext/convnext.py
index 8eea9df..9d2b3ec 100644
--- a/text_recognizer/network/convnext/convnext.py
+++ b/text_recognizer/network/convnext/convnext.py
@@ -4,9 +4,9 @@ from typing import Optional, Sequence
import torch
from torch import Tensor, nn
-from .transformer import Transformer
from .downsample import Downsample
from .norm import LayerNorm
+from .transformer import Transformer
class GRN(nn.Module):
diff --git a/text_recognizer/network/transformer/embedding/l2_norm.py b/text_recognizer/network/transformer/embedding/l2_norm.py
index 0e48bca..f5ec4ba 100644
--- a/text_recognizer/network/transformer/embedding/l2_norm.py
+++ b/text_recognizer/network/transformer/embedding/l2_norm.py
@@ -1,5 +1,5 @@
-from einops import rearrange
import torch.nn.functional as F
+from einops import rearrange
from torch import Tensor
diff --git a/text_recognizer/network/transformer/embedding/rotary.py b/text_recognizer/network/transformer/embedding/rotary.py
index 2254f81..7e88264 100644
--- a/text_recognizer/network/transformer/embedding/rotary.py
+++ b/text_recognizer/network/transformer/embedding/rotary.py
@@ -1,6 +1,6 @@
import torch
-from torch import nn, einsum
from einops import rearrange
+from torch import einsum, nn
class RotaryEmbedding(nn.Module):
diff --git a/text_recognizer/network/transformer/embedding/token.py b/text_recognizer/network/transformer/embedding/token.py
index 838f514..bf5149e 100644
--- a/text_recognizer/network/transformer/embedding/token.py
+++ b/text_recognizer/network/transformer/embedding/token.py
@@ -1,4 +1,4 @@
-from torch import nn, Tensor
+from torch import Tensor, nn
from .l2_norm import l2_norm
diff --git a/training/artifacts.py b/training/artifacts.py
index 2cb8bda..f7af085 100644
--- a/training/artifacts.py
+++ b/training/artifacts.py
@@ -1,14 +1,15 @@
"""Fetches model artifacts from wandb."""
-from datetime import datetime
-from pathlib import Path
import shutil
import sys
+from datetime import datetime
+from pathlib import Path
from typing import Optional
import click
from loguru import logger as log
-from training import metadata
+
import wandb
+from training import metadata
from wandb.apis.public import Run
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py
index 6adbebe..0841000 100644
--- a/training/callbacks/wandb.py
+++ b/training/callbacks/wandb.py
@@ -2,11 +2,12 @@
from pathlib import Path
from typing import Tuple
-import wandb
-from torch import Tensor
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
+from torch import Tensor
+
+import wandb
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
diff --git a/training/main.py b/training/main.py
index c36e397..17b961c 100644
--- a/training/main.py
+++ b/training/main.py
@@ -1,6 +1,7 @@
"""Loads config with hydra and runs experiment."""
import hydra
from omegaconf import DictConfig
+
from training.metadata import TRAINING_DIR
diff --git a/training/run.py b/training/run.py
index b8f2178..2089a86 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,19 +2,19 @@
from typing import Callable, List, Optional, Type
import hydra
+import utils
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
- seed_everything,
Trainer,
+ seed_everything,
)
from pytorch_lightning.loggers import Logger
from torch import nn
from torchinfo import summary
-import utils
def run(config: DictConfig) -> Optional[float]:
diff --git a/training/utils.py b/training/utils.py
index d1801a7..f0a0d4d 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,6 +1,6 @@
"""Util functions for training with hydra and pytorch lightning."""
-from typing import List, Type
import warnings
+from typing import List, Type
import hydra
from loguru import logger as log
@@ -14,6 +14,7 @@ from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
+
import wandb