path: root/text_recognizer
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-07-06 17:42:53 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-07-06 17:42:53 +0200
commit     eb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree       0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer
parent     4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/callbacks/wandb_callbacks.py       95
-rw-r--r--  text_recognizer/data/base_data_module.py           29
-rw-r--r--  text_recognizer/data/emnist.py                      22
-rw-r--r--  text_recognizer/data/emnist_lines.py                35
-rw-r--r--  text_recognizer/data/iam.py                          6
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py     33
-rw-r--r--  text_recognizer/data/iam_lines.py                   22
-rw-r--r--  text_recognizer/data/iam_paragraphs.py              32
-rw-r--r--  text_recognizer/models/base.py                       5
-rw-r--r--  text_recognizer/models/metrics.py                   15
-rw-r--r--  text_recognizer/models/transformer.py               26
-rw-r--r--  text_recognizer/models/vqvae.py                     34
-rw-r--r--  text_recognizer/networks/util.py                     2
13 files changed, 188 insertions, 168 deletions
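
The recurring change across these files is the same: hand-written __init__ methods are replaced by attrs field declarations, with super().__init__() moved into __attrs_pre_init__ and derived state into __attrs_post_init__. A minimal sketch of that pattern, assuming attrs >= 20.3.0 (which introduced __attrs_pre_init__) and pytorch-lightning; the class is illustrative and not part of the repo:

import attr
from pytorch_lightning import LightningDataModule


@attr.s(auto_attribs=True)
class ExampleDataModule(LightningDataModule):
    """Illustrative only: the shape the refactored classes below share."""

    # Declared fields replace hand-written __init__ arguments.
    batch_size: int = attr.ib(default=16)
    num_workers: int = attr.ib(default=0)
    # init=False fields are placeholders excluded from the generated __init__.
    dims: tuple = attr.ib(init=False, default=None)

    def __attrs_pre_init__(self) -> None:
        # Runs before attrs assigns the fields, so Lightning's own __init__ still executes.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Derived state that used to sit at the end of __init__ goes here.
        self.dims = (1, 28, 28)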
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 4186b4a..d9d81f6 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -93,6 +93,40 @@ class LogTextPredictions(Callback):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ def _log_predictions(
+ self, stage: str, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs the predicted text contained in the images."""
+ if not self.ready:
+ return None
+
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+
+ # Get a validation batch from the validation dataloader.
+ samples = next(iter(trainer.datamodule.val_dataloader()))
+ imgs, labels = samples
+
+ imgs = imgs.to(device=pl_module.device)
+ logits = pl_module(imgs)
+
+ mapping = pl_module.mapping
+ experiment.log(
+ {
+ f"OCR/{experiment.name}/{stage}": [
+ wandb.Image(
+ img,
+ caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
+ )
+ for img, pred, label in zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ ]
+ }
+ )
+
def on_sanity_check_start(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
@@ -107,6 +141,27 @@ class LogTextPredictions(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
+ self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ self._log_predictions(stage="train", trainer=trainer, pl_module=pl_module)
+
+
+@attr.s
+class LogReconstuctedImages(Callback):
+ """Log reconstructions of images."""
+
+ num_samples: int = attr.ib(default=8)
+ ready: bool = attr.ib(default=True)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def _log_reconstruction(
+ self, stage: str, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs the reconstructions."""
if not self.ready:
return None
@@ -115,24 +170,42 @@ class LogTextPredictions(Callback):
# Get a validation batch from the validation dataloader.
samples = next(iter(trainer.datamodule.val_dataloader()))
- imgs, labels = samples
+ imgs, _ = samples
imgs = imgs.to(device=pl_module.device)
- logits = pl_module(imgs)
+ reconstructions = pl_module(imgs)
- mapping = pl_module.mapping
experiment.log(
{
- f"Images/{experiment.name}": [
- wandb.Image(
- img,
- caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
- )
- for img, pred, label in zip(
+ f"Reconstructions/{experiment.name}/{stage}": [
+ [
+ wandb.Image(img),
+ wandb.Image(rec),
+ ]
+ for img, rec in zip(
imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
+ reconstructions[: self.num_samples],
)
]
}
)
+
+ def on_sanity_check_start(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Sets ready attribute."""
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_validation_epoch_end(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs predictions on validation epoch end."""
+ self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ self._log_reconstruction(stage="train", trainer=trainer, pl_module=pl_module)
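
For context, a hypothetical way to attach these callbacks to a run; the project name, sample count, and trainer arguments are placeholders, and the repo's actual entry point presumably builds the Trainer from its Hydra config instead:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from text_recognizer.callbacks.wandb_callbacks import (
    LogReconstuctedImages,
    LogTextPredictions,
)

trainer = Trainer(
    logger=WandbLogger(project="text-recognizer"),  # placeholder project name
    callbacks=[
        LogTextPredictions(num_samples=8),
        LogReconstuctedImages(num_samples=8),
    ],
    max_epochs=1,
)
# trainer.fit(lit_model, datamodule=datamodule)  # model and datamodule built elsewhere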
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index de5628f..18b1996 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,13 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Dict
+from typing import Any, Dict, Tuple
import attr
-import pytorch_lightning as LightningDataModule
+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.base_dataset import BaseDataset
+
def load_and_print_info(data_module_class: type) -> None:
"""Load dataset and print dataset information."""
@@ -19,17 +21,20 @@ def load_and_print_info(data_module_class: type) -> None:
class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- batch_size: int = attr.ib(default=16)
- num_workers: int = attr.ib(default=0)
-
def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self) -> None:
- # Placeholders for subclasses.
- self.dims = None
- self.output_dims = None
- self.mapping = None
+ batch_size: int = attr.ib(default=16)
+ num_workers: int = attr.ib(default=0)
+
+ # Placeholders
+ data_train: BaseDataset = attr.ib(init=False, default=None)
+ data_val: BaseDataset = attr.ib(init=False, default=None)
+ data_test: BaseDataset = attr.ib(init=False, default=None)
+ dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
+ mapping: Any = attr.ib(init=False, default=None)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -58,9 +63,7 @@ class BaseDataModule(LightningDataModule):
stage (Any): Variable to set splits.
"""
- self.data_train = None
- self.data_val = None
- self.data_test = None
+ pass
def train_dataloader(self) -> DataLoader:
"""Retun DataLoader for train data."""
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 824b947..d51a42a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,9 +3,10 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
import zipfile
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -32,6 +33,7 @@ PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+@attr.s(auto_attribs=True)
class EMNIST(BaseDataModule):
"""Lightning DataModule class for loading EMNIST dataset.
@@ -44,18 +46,12 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- def __init__(
- self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.train_fraction = train_fraction
- self.mapping, self.inverse_mapping, self.input_shape = emnist_mapping()
- self.data_train = None
- self.data_val = None
- self.data_test = None
- self.transform = T.Compose([T.ToTensor()])
- self.dims = (1, *self.input_shape)
- self.output_dims = (1,)
+ train_fraction: float = attr.ib(default=0.8)
+ transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
+
+ def __attrs_post_init__(self) -> None:
+ self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
+ self.dims = (1, *input_shape)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 9650198..4747508 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -3,6 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, Tuple
+import attr
import h5py
from loguru import logger
import numpy as np
@@ -31,31 +32,20 @@ IMAGE_X_PADDING = 28
MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+@attr.s(auto_attribs=True)
class EMNISTLines(BaseDataModule):
"""EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
- def __init__(
- self,
- augment: bool = True,
- batch_size: int = 128,
- num_workers: int = 0,
- max_length: int = 32,
- min_overlap: float = 0.0,
- max_overlap: float = 0.33,
- num_train: int = 10_000,
- num_val: int = 2_000,
- num_test: int = 2_000,
- ) -> None:
- super().__init__(batch_size, num_workers)
-
- self.augment = augment
- self.max_length = max_length
- self.min_overlap = min_overlap
- self.max_overlap = max_overlap
- self.num_train = num_train
- self.num_val = num_val
- self.num_test = num_test
+ augment: bool = attr.ib(default=True)
+ max_length: int = attr.ib(default=128)
+ min_overlap: float = attr.ib(default=0.0)
+ max_overlap: float = attr.ib(default=0.33)
+ num_train: int = attr.ib(default=10_000)
+ num_val: int = attr.ib(default=2_000)
+ num_test: int = attr.ib(default=2_000)
+ emnist: EMNIST = attr.ib(init=False, default=None)
+ def __attrs_post_init__(self) -> None:
self.emnist = EMNIST()
self.mapping = self.emnist.mapping
@@ -75,9 +65,6 @@ class EMNISTLines(BaseDataModule):
raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
@property
def data_filename(self) -> Path:
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 261c8d3..3982c4f 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List
import xml.etree.ElementTree as ElementTree
import zipfile
+import attr
from boltons.cacheutils import cachedproperty
from loguru import logger
import toml
@@ -22,6 +23,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
+@attr.s(auto_attribs=True)
class IAM(BaseDataModule):
"""
"The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
@@ -35,9 +37,7 @@ class IAM(BaseDataModule):
The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
"""
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
- super().__init__(batch_size, num_workers)
- self.metadata = toml.load(METADATA_FILENAME)
+ metadata: Dict = attr.ib(init=False, default=toml.load(METADATA_FILENAME))
def prepare_data(self) -> None:
if self.xml_filenames:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0a30a42..886e37e 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -1,4 +1,7 @@
"""IAM original and sythetic dataset class."""
+from typing import Dict, List
+
+import attr
from torch.utils.data import ConcatDataset
from text_recognizer.data.base_dataset import BaseDataset
@@ -7,22 +10,26 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+@attr.s(auto_attribs=True)
class IAMExtendedParagraphs(BaseDataModule):
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
+ train_fraction: float = attr.ib(default=0.8)
+ augment: bool = attr.ib(default=True)
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- batch_size, num_workers, train_fraction, augment, word_pieces,
+ self.batch_size,
+ self.num_workers,
+ self.train_fraction,
+ self.augment,
+ self.word_pieces,
)
self.dims = self.iam_paragraphs.dims
@@ -30,10 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.mapping = self.iam_paragraphs.mapping
self.inverse_mapping = self.iam_paragraphs.inverse_mapping
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
-
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
self.iam_paragraphs.prepare_data()
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 9c78a22..e45e5c8 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -7,8 +7,9 @@ dataset.
import json
from pathlib import Path
import random
-from typing import List, Sequence, Tuple
+from typing import Dict, List, Sequence, Tuple
+import attr
from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
@@ -35,26 +36,17 @@ IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
+@attr.s(auto_attribs=True)
class IAMLines(BaseDataModule):
"""IAM handwritten lines dataset."""
- def __init__(
- self,
- augment: bool = True,
- fraction: float = 0.8,
- batch_size: int = 128,
- num_workers: int = 0,
- ) -> None:
- # TODO: add transforms
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.fraction = fraction
+ augment: bool = attr.ib(default=True)
+ fraction: float = attr.ib(default=0.8)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping()
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (89, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index fe60e99..445b788 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -3,6 +3,7 @@ import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
+import attr
from loguru import logger
import numpy as np
from PIL import Image, ImageOps
@@ -33,33 +34,25 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
MAX_LABEL_LENGTH = 682
+@attr.s(auto_attribs=True)
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers)
- self.augment = augment
- self.word_pieces = word_pieces
+ augment: bool = attr.ib(default=True)
+ train_fraction: float = attr.ib(default=0.8)
+ word_pieces: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping(
extra_symbols=[NEW_LINE_TOKEN]
)
- if word_pieces:
+ if self.word_pieces:
self.mapping = WordPieceMapping()
- self.train_fraction = train_fraction
self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
self.output_dims = (MAX_LABEL_LENGTH, 1)
- self.data_train: BaseDataset = None
- self.data_val: BaseDataset = None
- self.data_test: BaseDataset = None
def prepare_data(self) -> None:
"""Create data for training/testing."""
@@ -166,7 +159,10 @@ def get_dataset_properties() -> Dict:
"min": min(_get_property_values("num_lines")),
"max": max(_get_property_values("num_lines")),
},
- "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),},
+ "crop_shape": {
+ "min": crop_shapes.min(axis=0),
+ "max": crop_shapes.max(axis=0),
+ },
"aspect_ratio": {
"min": aspect_ratio.min(axis=0),
"max": aspect_ratio.max(axis=0),
@@ -287,7 +283,9 @@ def get_transform(image_shape: Tuple[int, int], augment: bool) -> T.Compose:
),
T.ColorJitter(brightness=(0.8, 1.6)),
T.RandomAffine(
- degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR,
+ degrees=1,
+ shear=(-10, 10),
+ interpolation=InterpolationMode.BILINEAR,
),
]
else:
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8dc7a36..f95df0f 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -5,7 +5,7 @@ import attr
import hydra
from loguru import logger as log
from omegaconf import DictConfig
-import pytorch_lightning as pl
+from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class BaseLitModel(pl.LightningModule):
+class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -80,7 +80,6 @@ class BaseLitModel(pl.LightningModule):
"""Configures optimizer and lr scheduler."""
optimizer = self._configure_optimizer()
scheduler = self._configure_lr_scheduler(optimizer)
-
return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor:
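
The optimizer and scheduler returned above come from the _configure_optimizer / _configure_lr_scheduler helpers, which consume the optimizer_config and lr_scheduler_config fields. Assuming those DictConfigs follow Hydra's _target_ convention (the commit message points at Hydra, but the exact keys the repo expects are not shown in this hunk), they might look like:

from omegaconf import OmegaConf

# Hypothetical configs; names and values are placeholders, not the repo's conf/ files.
optimizer_config = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "lr": 3e-4, "weight_decay": 1e-4}
)
lr_scheduler_config = OmegaConf.create(
    {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 100}
)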
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 58d0537..4117ae2 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,18 +1,23 @@
"""Character Error Rate (CER)."""
-from typing import Sequence
+from typing import Set, Sequence
+import attr
import editdistance
import torch
from torch import Tensor
-import torchmetrics
+from torchmetrics import Metric
-class CharacterErrorRate(torchmetrics.Metric):
+@attr.s
+class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+ ignore_tokens: Set = attr.ib(converter=set)
+ error: Tensor = attr.ib(init=False)
+ total: Tensor = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
super().__init__()
- self.ignore_tokens = set(ignore_tokens)
self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index ea54d83..8c9fe8a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -2,35 +2,24 @@
from typing import Dict, List, Optional, Union, Tuple, Type
import attr
+import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-@attr.s
-class TransformerLitModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- network: Type[nn.Module] = attr.ib()
- criterion_config: DictConfig = attr.ib(converter=DictConfig)
- optimizer_config: DictConfig = attr.ib(converter=DictConfig)
- lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
- monitor: str = attr.ib()
- mapping: Type[AbstractMapping] = attr.ib()
+ mapping_config: DictConfig = attr.ib(converter=DictConfig)
def __attrs_post_init__(self) -> None:
- super().__init__(
- network=self.network,
- optimizer_config=self.optimizer_config,
- lr_scheduler_config=self.lr_scheduler_config,
- criterion_config=self.criterion_config,
- monitor=self.monitor,
- )
- self.mapping, ignore_tokens = self.configure_mapping(mapping)
+ self.mapping, ignore_tokens = self._configure_mapping()
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
@@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel):
return self.network.predict(data)
@staticmethod
- def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
+ def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
"""Configure mapping."""
# TODO: Fix me!!!
+ # Load config with hydra
mapping, inverse_mapping, _ = emnist_mapping(["\n"])
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 7dc950f..0172163 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -1,49 +1,23 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Any, Dict, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn
from torch import Tensor
import wandb
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-class LitVQVAEModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val/loss",
- *args: Any,
- **kwargs: Dict,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- def _log_prediction(
- self, data: Tensor, reconstructions: Tensor, title: str
- ) -> None:
- """Logs prediction on image with wandb."""
- try:
- self.logger.experiment.log(
- {
- title: [
- wandb.Image(data[0]),
- wandb.Image(reconstructions[0]),
- ]
- }
- )
- except AttributeError:
- pass
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, _ = batch
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 109bf4d..85094f1 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,4 +1,4 @@
-"""Miscellaneous neural network functionality."""
+"""Miscellaneous neural network utility functionality."""
from typing import Type
from torch import nn