Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/callbacks/__init__.py                                  |  1
-rw-r--r--  text_recognizer/callbacks/wandb_callbacks.py                           |  8
-rw-r--r--  text_recognizer/criterions/__init__.py                                 |  1
-rw-r--r--  text_recognizer/criterions/label_smoothing_loss.py (renamed from text_recognizer/networks/loss/label_smoothing_loss.py) |  0
-rw-r--r--  text_recognizer/data/base_data_module.py                               | 14
-rw-r--r--  text_recognizer/data/base_dataset.py                                   | 24
-rw-r--r--  text_recognizer/models/__init__.py                                     |  2
-rw-r--r--  text_recognizer/models/base.py                                         | 11
-rw-r--r--  text_recognizer/models/transformer.py                                  | 30
-rw-r--r--  text_recognizer/models/vqvae.py                                        |  6
-rw-r--r--  text_recognizer/networks/cnn_tranformer.py                             | 14
-rw-r--r--  text_recognizer/networks/loss/__init__.py                              |  2
-rw-r--r--  text_recognizer/networks/util.py                                       |  7
13 files changed, 70 insertions(+), 50 deletions(-)
diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py
index e69de29..82d8ce3 100644
--- a/text_recognizer/callbacks/__init__.py
+++ b/text_recognizer/callbacks/__init__.py
@@ -0,0 +1 @@
+"""Module for PyTorch Lightning callbacks."""
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 900c3b1..4186b4a 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -29,7 +29,7 @@ class WatchModel(Callback):
log: str = attr.ib(default="gradients")
log_freq: int = attr.ib(default=100)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -44,7 +44,7 @@ class UploadCodeAsArtifact(Callback):
project_dir: Path = attr.ib(converter=Path)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -65,7 +65,7 @@ class UploadCheckpointAsArtifact(Callback):
ckpt_dir: Path = attr.ib(converter=Path)
upload_best_only: bool = attr.ib()
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -90,7 +90,7 @@ class LogTextPredictions(Callback):
num_samples: int = attr.ib(default=8)
ready: bool = attr.ib(default=True)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_sanity_check_start(
diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py
new file mode 100644
index 0000000..5b0a7ab
--- /dev/null
+++ b/text_recognizer/criterions/__init__.py
@@ -0,0 +1 @@
+"""Module with custom loss functions."""
diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py
index 40a7609..40a7609 100644
--- a/text_recognizer/networks/loss/label_smoothing_loss.py
+++ b/text_recognizer/criterions/label_smoothing_loss.py
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 8b5c188..de5628f 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,8 @@
from pathlib import Path
from typing import Dict
-import pytorch_lightning as pl
+import attr
+from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
@@ -14,14 +15,17 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-class BaseDataModule(pl.LightningDataModule):
+@attr.s
+class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+ batch_size: int = attr.ib(default=16)
+ num_workers: int = attr.ib(default=0)
+
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.batch_size = batch_size
- self.num_workers = num_workers
+ def __attrs_post_init__(self) -> None:
# Placeholders for subclasses.
self.dims = None
self.output_dims = None
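
The conversion above relies on two attrs hooks: __attrs_pre_init__ runs before the generated __init__ assigns any attributes (so the Lightning base class gets initialized first), and __attrs_post_init__ runs afterwards, taking over the body of the old __init__. A minimal stand-alone sketch of the same pattern, assuming attrs >= 20.3 and pytorch_lightning are installed; the class name is illustrative, not from the repository:

# Minimal sketch of the attrs + LightningDataModule pattern used above.
import attr
from pytorch_lightning import LightningDataModule


@attr.s
class ToyDataModule(LightningDataModule):
    batch_size: int = attr.ib(default=16)
    num_workers: int = attr.ib(default=0)

    def __attrs_pre_init__(self) -> None:
        # Runs before attrs assigns attributes; the Lightning constructor
        # must still be called exactly once.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Runs after attribute assignment; replaces the old __init__ body.
        self.dims = None
        self.output_dims = None


dm = ToyDataModule(batch_size=32)
print(dm.batch_size)  # 32
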
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 8d644d4..4318dfb 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,11 +1,13 @@
"""Base PyTorch Dataset class."""
from typing import Any, Callable, Dict, Sequence, Tuple, Union
+import attr
import torch
from torch import Tensor
from torch.utils.data import Dataset
+@attr.s
class BaseDataset(Dataset):
"""
Base Dataset class that processes data and targets through optional transforms.
@@ -18,19 +20,17 @@ class BaseDataset(Dataset):
target transforms.
"""
- def __init__(
- self,
- data: Union[Sequence, Tensor],
- targets: Union[Sequence, Tensor],
- transform: Callable = None,
- target_transform: Callable = None,
- ) -> None:
- if len(data) != len(targets):
+ data: Union[Sequence, Tensor] = attr.ib()
+ targets: Union[Sequence, Tensor] = attr.ib()
+    transform: Callable = attr.ib(default=None)
+    target_transform: Callable = attr.ib(default=None)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
- self.data = data
- self.targets = targets
- self.transform = transform
- self.target_transform = target_transform
def __len__(self) -> int:
"""Return the length of the dataset."""
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index 5ac2510..1982daf 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -1,3 +1 @@
"""PyTorch Lightning models modules."""
-from .transformer import LitTransformerModel
-from .vqvae import LitVQVAEModel
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 4e803eb..8dc7a36 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Union, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class LitBaseModel(pl.LightningModule):
+class BaseLitModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule):
val_acc = attr.ib(init=False)
test_acc = attr.ib(init=False)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self):
- self.loss_fn = self.configure_criterion()
+ def __attrs_post_init__(self) -> None:
+        self.loss_fn = self.configure_criterion()
# Accuracy metric
self.train_acc = torchmetrics.Accuracy()
self.val_acc = torchmetrics.Accuracy()
self.test_acc = torchmetrics.Accuracy()
- @staticmethod
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6be0ac5..ea54d83 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,27 +1,35 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-class LitTransformerModel(LitBaseModel):
+@attr.s
+class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
- mapping: Optional[List[str]] = None,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+ network: Type[nn.Module] = attr.ib()
+ criterion_config: DictConfig = attr.ib(converter=DictConfig)
+ optimizer_config: DictConfig = attr.ib(converter=DictConfig)
+ lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ monitor: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ def __attrs_post_init__(self) -> None:
+ super().__init__(
+ network=self.network,
+ optimizer_config=self.optimizer_config,
+ lr_scheduler_config=self.lr_scheduler_config,
+ criterion_config=self.criterion_config,
+ monitor=self.monitor,
+ )
-        self.mapping, ignore_tokens = self.configure_mapping(mapping)
+        self.mapping, ignore_tokens = self.configure_mapping(self.mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 18e8691..7dc950f 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -18,7 +18,7 @@ class LitVQVAEModel(LitBaseModel):
optimizer: Union[DictConfig, Dict],
lr_scheduler: Union[DictConfig, Dict],
criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
+ monitor: str = "val/loss",
*args: Any,
**kwargs: Dict,
) -> None:
@@ -50,7 +50,7 @@ class LitVQVAEModel(LitBaseModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- self.log("train_loss", loss)
+ self.log("train/loss", loss)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -59,7 +59,7 @@ class LitVQVAEModel(LitBaseModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- self.log("val_loss", loss, prog_bar=True)
+ self.log("val/loss", loss, prog_bar=True)
title = "val_pred_examples"
self._log_prediction(data, reconstructions, title)
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
new file mode 100644
index 0000000..da69311
--- /dev/null
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -0,0 +1,14 @@
+"""Vision transformer for character recognition."""
+from typing import Type
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class CnnTransformer(nn.Module):
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ backbone: Type[nn.Module] = attr.ib()
+    head: Type[nn.Module] = attr.ib()
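
The new module only declares its two sub-networks; a forward pass is presumably added later. A hypothetical sketch of how a backbone-plus-head module in this attrs style can be wired up and exercised (the layers and shapes are invented for illustration):

# Hypothetical backbone + head composition mirroring the attrs layout above.
import attr
import torch
from torch import nn, Tensor


@attr.s
class ToyBackboneHead(nn.Module):
    backbone: nn.Module = attr.ib()
    head: nn.Module = attr.ib()

    def __attrs_pre_init__(self) -> None:
        # nn.Module.__init__ must run before attrs assigns the submodules,
        # otherwise module registration fails.
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return self.head(self.backbone(x))


net = ToyBackboneHead(
    backbone=nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Flatten()),
    head=nn.Linear(8 * 28 * 28, 10),
)
print(net(torch.zeros(2, 1, 28, 28)).shape)  # torch.Size([2, 10])
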
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
deleted file mode 100644
index cb83608..0000000
--- a/text_recognizer/networks/loss/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Loss module."""
-from .loss import LabelSmoothingCrossEntropy
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 05b10a8..109bf4d 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,10 +1,6 @@
"""Miscellaneous neural network functionality."""
-import importlib
-from pathlib import Path
-from typing import Dict, NamedTuple, Union, Type
+from typing import Type
-from loguru import logger
-import torch
from torch import nn
@@ -19,6 +15,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
["none", nn.Identity()],
["relu", nn.ReLU(inplace=True)],
["selu", nn.SELU(inplace=True)],
+ ["mish", nn.Mish(inplace=True)],
]
)
return activation_fns[activation.lower()]
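
With the new entry, the factory can hand back Mish as well. A small usage sketch; note that nn.Mish only ships with PyTorch 1.9 or later:

# Usage sketch of the updated activation factory (PyTorch >= 1.9 for nn.Mish).
import torch
from text_recognizer.networks.util import activation_function

act = activation_function("mish")
print(act(torch.linspace(-2.0, 2.0, steps=5)))
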