summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py11
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py8
-rw-r--r--src/text_recognizer/datasets/transforms.py18
3 files changed, 35 insertions, 2 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6871492..eddf341 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -10,6 +10,7 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
+import torch.nn.functional as F
from torchvision.transforms import ToTensor
from text_recognizer.datasets.dataset import Dataset
@@ -23,6 +24,8 @@ from text_recognizer.datasets.util import (
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+MAX_WIDTH = 952
+
class EmnistLinesDataset(Dataset):
"""Synthetic dataset of lines from the Brown corpus with Emnist characters."""
@@ -254,6 +257,14 @@ def construct_image_from_string(
for image in sampled_images:
concatenated_image[:, x : (x + width)] += image
x += next_overlap_width
+
+ if concatenated_image.shape[-1] > MAX_WIDTH:
+ concatenated_image = Tensor(concatenated_image).unsqueeze(0)
+ concatenated_image = F.interpolate(
+ concatenated_image, size=MAX_WIDTH, mode="nearest"
+ )
+ concatenated_image = concatenated_image.squeeze(0).numpy()
+
return np.minimum(255, concatenated_image)
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index c1e8fe2..8ba5142 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -1,4 +1,5 @@
"""IamParagraphsDataset class and functions for data processing."""
+import random
from typing import Callable, Dict, List, Optional, Tuple, Union
import click
@@ -71,13 +72,18 @@ class IamParagraphsDataset(Dataset):
data = self.data[index]
targets = self.targets[index]
+ seed = np.random.randint(SEED)
+ random.seed(seed)  # apply this seed to data transforms
+ torch.manual_seed(seed) # needed for torchvision 0.7
if self.transform:
data = self.transform(data)
+ random.seed(seed)  # apply this seed to target transforms
+ torch.manual_seed(seed) # needed for torchvision 0.7
if self.target_transform:
targets = self.target_transform(targets)
- return data, targets
+ return data, targets.long()
@property
def ids(self) -> Tensor:
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 1ec23dc..016ec80 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -4,7 +4,7 @@ from PIL import Image
import torch
from torch import Tensor
import torch.nn.functional as F
-from torchvision.transforms import Compose, RandomAffine, ToTensor
+from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor
from text_recognizer.datasets.util import EmnistMapper
@@ -77,3 +77,19 @@ class ApplyContrast:
"""Apply binary mask to input tensor."""
mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
return x * mask
+
+
+class Unsqueeze:
+ """Add a dimension to the tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Adds dim."""
+ return x.unsqueeze(0)
+
+
+class Squeeze:
+ """Removes the first dimension of a tensor if it has size 1."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Removes first dim."""
+ return x.squeeze(0)