diff options
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/datasets/transforms.py | 15 | ||||
-rw-r--r-- | src/text_recognizer/line_predictor.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/models/__init__.py | 5 | ||||
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 5 | ||||
-rw-r--r-- | src/text_recognizer/networks/metrics.py (renamed from src/text_recognizer/models/metrics.py) | 0 |
5 files changed, 17 insertions, 12 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index d1ca127..1ec23dc 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -4,7 +4,7 @@ from PIL import Image import torch from torch import Tensor import torch.nn.functional as F -from torchvision.transforms import Compose, ToPILImage, ToTensor +from torchvision.transforms import Compose, RandomAffine, ToTensor from text_recognizer.datasets.util import EmnistMapper @@ -66,9 +66,14 @@ class AddTokens: return target -class Whitening: - """Whitening of Tensor, i.e. set mean to zero and std to one.""" +class ApplyContrast: + """Sets everything below a threshold to zero, i.e. increase contrast.""" + + def __init__(self, low: float = 0.0, high: float = 0.25) -> None: + self.low = low + self.high = high def __call__(self, x: Tensor) -> Tensor: - """Apply the whitening.""" - return (x - x.mean()) / x.std() + """Apply mask binary mask to input tensor.""" + mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) + return x * mask diff --git a/src/text_recognizer/line_predictor.py b/src/text_recognizer/line_predictor.py index 981e2c9..8e348fe 100644 --- a/src/text_recognizer/line_predictor.py +++ b/src/text_recognizer/line_predictor.py @@ -6,7 +6,7 @@ import numpy as np from torch import nn from text_recognizer import datasets, networks -from text_recognizer.models import VisionTransformerModel +from text_recognizer.models import TransformerModel from text_recognizer.util import read_image @@ -16,7 +16,7 @@ class LinePredictor: def __init__(self, dataset: str, network_fn: str) -> None: network_fn = getattr(networks, network_fn) dataset = getattr(datasets, dataset) - self.model = VisionTransformerModel(network_fn=network_fn, dataset=dataset) + self.model = TransformerModel(network_fn=network_fn, dataset=dataset) self.model.eval() def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index 53340f1..bf89404 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -2,16 +2,11 @@ from .base import Model from .character_model import CharacterModel from .crnn_model import CRNNModel -from .metrics import accuracy, accuracy_ignore_pad, cer, wer from .transformer_model import TransformerModel __all__ = [ - "accuracy", - "accuracy_ignore_pad", - "cer", "CharacterModel", "CRNNModel", "Model", "TransformerModel", - "wer", ] diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 2cc1137..078d771 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -4,6 +4,7 @@ from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder from .densenet import DenseNet from .lenet import LeNet +from .metrics import accuracy, accuracy_ignore_pad, cer, wer from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder from .sparse_mlp import SparseMLP @@ -12,6 +13,9 @@ from .util import sliding_window from .wide_resnet import WideResidualNetwork __all__ = [ + "accuracy", + "accuracy_ignore_pad", + "cer", "CNNTransformer", "ConvolutionalRecurrentNetwork", "DenseNet", @@ -23,5 +27,6 @@ __all__ = [ "sliding_window", "Transformer", "SparseMLP", + "wer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/networks/metrics.py index af9adb5..af9adb5 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/networks/metrics.py |