diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/emnist.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 14 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 20 | ||||
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/dino.py | 28 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 4 |
6 files changed, 50 insertions, 24 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index bf3faec..824b947 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -10,7 +10,7 @@ import h5py from loguru import logger import numpy as np import toml -from torchvision import transforms +import torchvision.transforms as T from text_recognizer.data.base_data_module import ( BaseDataModule, @@ -53,7 +53,7 @@ class EMNIST(BaseDataModule): self.data_train = None self.data_val = None self.data_test = None - self.transform = transforms.Compose([transforms.ToTensor()]) + self.transform = T.Compose([T.ToTensor()]) self.dims = (1, *self.input_shape) self.output_dims = (1,) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 78bc8e1..9c78a22 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -13,7 +13,7 @@ from loguru import logger from PIL import Image, ImageFile, ImageOps import numpy as np from torch import Tensor -from torchvision import transforms +import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from text_recognizer.data.base_dataset import ( @@ -208,7 +208,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li return crops, labels -def get_transform(image_width: int, augment: bool = False) -> transforms.Compose: +def get_transform(image_width: int, augment: bool = False) -> T.Compose: """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise.""" def embed_crop( @@ -237,20 +237,20 @@ def get_transform(image_width: int, augment: bool = False) -> transforms.Compose return image - transfroms_list = [transforms.Lambda(embed_crop)] + transfroms_list = [T.Lambda(embed_crop)] if augment: transfroms_list += [ - transforms.ColorJitter(brightness=(0.8, 1.6)), - transforms.RandomAffine( + T.ColorJitter(brightness=(0.8, 1.6)), + T.RandomAffine( degrees=1, shear=(-30, 20), 
interpolation=InterpolationMode.BILINEAR, fill=0, ), ] - transfroms_list.append(transforms.ToTensor()) - return transforms.Compose(transfroms_list) + transfroms_list.append(T.ToTensor()) + return T.Compose(transfroms_list) def generate_iam_lines() -> None: diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 24409bc..6022804 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Sequence, Tuple from loguru import logger import numpy as np from PIL import Image, ImageOps -import torchvision.transforms as transforms +import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm @@ -270,31 +270,31 @@ def _load_processed_crops_and_labels( return ordered_crops, ordered_labels -def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose: +def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose: """Get transformations for images.""" if augment: transforms_list = [ - transforms.RandomCrop( + T.RandomCrop( size=image_shape, padding=None, pad_if_needed=True, fill=0, padding_mode="constant", ), - transforms.ColorJitter(brightness=(0.8, 1.6)), - transforms.RandomAffine( + T.ColorJitter(brightness=(0.8, 1.6)), + T.RandomAffine( degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, ), ] else: - transforms_list = [transforms.CenterCrop(image_shape)] - transforms_list.append(transforms.ToTensor()) - return transforms.Compose(transforms_list) + transforms_list = [T.CenterCrop(image_shape)] + transforms_list.append(T.ToTensor()) + return T.Compose(transforms_list) -def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]: +def get_target_transform(word_pieces: bool) -> Optional[T.Compose]: """Transform emnist characters to word pieces.""" - return transforms.Compose([WordPiece()]) if word_pieces else 
None + return T.Compose([WordPiece()]) if word_pieces else None def _labels_filename(split: str) -> Path: diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index ad6fa25..00fa2b6 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -1,8 +1,5 @@ """IAM Synthetic Paragraphs Dataset class.""" -import itertools -from pathlib import Path import random -import time from typing import Any, List, Sequence, Tuple from loguru import logger @@ -12,7+9,6 @@ from PIL import Image from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, - split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import ( diff --git a/text_recognizer/models/dino.py b/text_recognizer/models/dino.py new file mode 100644 index 0000000..dca954c --- /dev/null +++ b/text_recognizer/models/dino.py @@ -0,0 +1,28 @@ +"""Dino: pretraining of models with self supervision.""" +import copy +from functools import wraps, partial + +import torch +from torch import nn +import torch.nn.functional as F +import torchvision.transforms as T +import wandb + +from text_recognizer.models.base import LitBaseModel + + +def singleton(cache_key): + def inner_fn(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + instance = getattr(self, cache_key) + if instance is not None: + return instance + + instance = fn(self, *args, **kwargs) + setattr(self, cache_key, instance) + return instance + + return wrapper + + return inner_fn diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index 283b686..a59abf8 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -27,6 +27,7 @@ class EfficientNet(nn.Module):
def __init__( self, arch: str, + out_channels: int = 256, stochastic_dropout_rate: float = 0.2, bn_momentum: float = 0.99, bn_eps: float = 1.0e-3, @@ -34,6 +35,7 @@ class EfficientNet(nn.Module): super().__init__() assert arch in self.archs, f"{arch} not a valid efficient net architecure!" self.arch = self.archs[arch] + self.out_channels = out_channels self.stochastic_dropout_rate = stochastic_dropout_rate self.bn_momentum = 1 - bn_momentum self.bn_eps = bn_eps @@ -77,7 +79,7 @@ class EfficientNet(nn.Module): args.stride = 1 in_channels = round_filters(320, self.arch) - out_channels = round_filters(1280, self.arch) + out_channels = round_filters(self.out_channels, self.arch) self._conv_head = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d( |