Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/data/emnist.py                                  |  4 ++--
-rw-r--r--  text_recognizer/data/iam_lines.py                               | 14 +++++++-------
-rw-r--r--  text_recognizer/data/iam_paragraphs.py                          | 20 ++++++++++----------
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py                |  4 ----
-rw-r--r--  text_recognizer/models/dino.py                                  | 28 ++++++++++++++++++++++++++++
-rw-r--r--  text_recognizer/networks/encoders/efficientnet/efficientnet.py |  4 +++-
6 files changed, 50 insertions(+), 24 deletions(-)
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index bf3faec..824b947 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -10,7 +10,7 @@ import h5py
from loguru import logger
import numpy as np
import toml
-from torchvision import transforms
+import torchvision.transforms as T
from text_recognizer.data.base_data_module import (
BaseDataModule,
@@ -53,7 +53,7 @@ class EMNIST(BaseDataModule):
self.data_train = None
self.data_val = None
self.data_test = None
- self.transform = transforms.Compose([transforms.ToTensor()])
+ self.transform = T.Compose([T.ToTensor()])
self.dims = (1, *self.input_shape)
self.output_dims = (1,)
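
A minimal sketch (not part of the commit) showing that the `T` alias changes nothing about behavior; the dummy image is an assumption:

import numpy as np
from PIL import Image
import torchvision.transforms as T

transform = T.Compose([T.ToTensor()])  # same pipeline as before the rename
image = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))  # dummy EMNIST-sized input
tensor = transform(image)  # float tensor of shape (1, 28, 28), values in [0, 1]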
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 78bc8e1..9c78a22 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -13,7 +13,7 @@ from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
-from torchvision import transforms
+import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from text_recognizer.data.base_dataset import (
@@ -208,7 +208,7 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li
return crops, labels
-def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
+def get_transform(image_width: int, augment: bool = False) -> T.Compose:
"""Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise."""
def embed_crop(
@@ -237,20 +237,20 @@ def get_transform(image_width: int, augment: bool = False) -> transforms.Compose
return image
- transforms_list = [transforms.Lambda(embed_crop)]
+ transforms_list = [T.Lambda(embed_crop)]
if augment:
transforms_list += [
- transforms.ColorJitter(brightness=(0.8, 1.6)),
- transforms.RandomAffine(
+ T.ColorJitter(brightness=(0.8, 1.6)),
+ T.RandomAffine(
degrees=1,
shear=(-30, 20),
interpolation=InterpolationMode.BILINEAR,
fill=0,
),
]
- transforms_list.append(transforms.ToTensor())
- return transforms.Compose(transforms_list)
+ transforms_list.append(T.ToTensor())
+ return T.Compose(transforms_list)
def generate_iam_lines() -> None:
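
A hypothetical call site for the renamed pipeline; only get_transform itself appears in this commit, and the image width is an assumed value:

from text_recognizer.data.iam_lines import get_transform

transform = get_transform(image_width=1024, augment=True)  # image_width is assumed
# tensor = transform(line_crop)  # PIL line crop in, float tensor out:
# embed_crop -> ColorJitter -> RandomAffine -> ToTensor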
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 24409bc..6022804 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
-import torchvision.transforms as transforms
+import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
@@ -270,31 +270,31 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
+def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [
- transforms.RandomCrop(
+ T.RandomCrop(
size=image_shape,
padding=None,
pad_if_needed=True,
fill=0,
padding_mode="constant",
),
- transforms.ColorJitter(brightness=(0.8, 1.6)),
- transforms.RandomAffine(
+ T.ColorJitter(brightness=(0.8, 1.6)),
+ T.RandomAffine(
degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
),
]
else:
- transforms_list = [transforms.CenterCrop(image_shape)]
- transforms_list.append(transforms.ToTensor())
- return transforms.Compose(transforms_list)
+ transforms_list = [T.CenterCrop(image_shape)]
+ transforms_list.append(T.ToTensor())
+ return T.Compose(transforms_list)
-def get_target_transform(word_pieces: bool) -> Optional[transforms.Compose]:
+def get_target_transform(word_pieces: bool) -> Optional[T.Compose]:
"""Transform emnist characters to word pieces."""
- return transforms.Compose([WordPiece()]) if word_pieces else None
+ return T.Compose([WordPiece()]) if word_pieces else None
def _labels_filename(split: str) -> Path:
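
A sketch of the two branches of get_transform above; the paragraph shape is an assumption, not taken from this commit:

from text_recognizer.data.iam_paragraphs import get_transform, get_target_transform

train_transform = get_transform(image_shape=(576, 640), augment=True)  # RandomCrop + ColorJitter + RandomAffine + ToTensor
test_transform = get_transform(image_shape=(576, 640), augment=False)  # CenterCrop + ToTensor
target_transform = get_target_transform(word_pieces=False)  # None unless word pieces are requested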
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index ad6fa25..00fa2b6 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -1,8 +1,5 @@
"""IAM Synthetic Paragraphs Dataset class."""
-import itertools
-from pathlib import Path
import random
-import time
from typing import Any, List, Sequence, Tuple
from loguru import logger
@@ -12,7 +9,6 @@ from PIL import Image
from text_recognizer.data.base_dataset import (
BaseDataset,
convert_strings_to_labels,
- split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import (
diff --git a/text_recognizer/models/dino.py b/text_recognizer/models/dino.py
new file mode 100644
index 0000000..dca954c
--- /dev/null
+++ b/text_recognizer/models/dino.py
@@ -0,0 +1,28 @@
+"""Dino: pretraining of models with self supervision."""
+import copy
+from functools import wraps, partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+import wandb
+
+from text_recognizer.models.base import LitBaseModel
+
+
+def singleton(cache_key):
+ def inner_fn(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ instance = getattr(self, cache_key)
+ if instance is not None:
+ return instance
+
+ instance = fn(self, *args, **kwargs)
+ setattr(self, cache_key, instance)
+ return instance
+
+ return wrapper
+
+ return inner_fn
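
The decorator reads the cache attribute with getattr before writing it, so the attribute must already exist on the instance. A sketch of the intended usage, lazily building and caching a teacher copy of the student as in DINO; the class and attribute names are assumptions, only singleton is in this commit:

import copy
from torch import nn

class DinoSketch(nn.Module):
    def __init__(self, student: nn.Module) -> None:
        super().__init__()
        self.student = student
        self._teacher = None  # cache slot read by the singleton decorator

    @singleton("_teacher")
    def get_teacher(self) -> nn.Module:
        # Runs on the first call only; later calls return the cached copy.
        return copy.deepcopy(self.student)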
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 283b686..a59abf8 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -27,6 +27,7 @@ class EfficientNet(nn.Module):
def __init__(
self,
arch: str,
+ out_channels: int = 256,
stochastic_dropout_rate: float = 0.2,
bn_momentum: float = 0.99,
bn_eps: float = 1.0e-3,
@@ -34,6 +35,7 @@ class EfficientNet(nn.Module):
super().__init__()
assert arch in self.archs, f"{arch} not a valid efficient net architecture!"
self.arch = self.archs[arch]
+ self.out_channels = out_channels
self.stochastic_dropout_rate = stochastic_dropout_rate
self.bn_momentum = 1 - bn_momentum
self.bn_eps = bn_eps
@@ -77,7 +79,7 @@ class EfficientNet(nn.Module):
args.stride = 1
in_channels = round_filters(320, self.arch)
- out_channels = round_filters(1280, self.arch)
+ out_channels = round_filters(self.out_channels, self.arch)
self._conv_head = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(
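
With the new out_channels argument, the width of the head convolution is derived from a configurable base instead of the fixed 1280. A hypothetical construction; the "b0" key is an assumption about the archs table:

from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet

encoder = EfficientNet(arch="b0", out_channels=256)  # "b0" is an assumed arch key
# The head conv now emits round_filters(256, arch) channels
# instead of round_filters(1280, arch).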