summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/models/base.py55
-rw-r--r--text_recognizer/models/transformer.py14
-rw-r--r--text_recognizer/networks/image_transformer.py42
3 files changed, 68 insertions, 43 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 2d6e435..1004f48 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,33 +1,32 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Union, Tuple, Type
import madgrad
+from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer import networks
-
class LitBaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
def __init__(
self,
- network_args: Dict,
- optimizer_args: Dict,
- lr_scheduler_args: Dict,
- criterion_args: Dict,
+ network: Type[nn,Module],
+ optimizer: Union[OmegaConf, Dict],
+ lr_scheduler: Union[OmegaConf, Dict],
+ criterion: Union[OmegaConf, Dict],
monitor: str = "val_loss",
) -> None:
super().__init__()
self.monitor = monitor
- self.network = getattr(networks, network_args["type"])(**network_args["args"])
- self.optimizer_args = optimizer_args
- self.lr_scheduler_args = lr_scheduler_args
- self.loss_fn = self.configure_criterion(criterion_args)
+ self.network = network
+ self._optimizer = OmegaConf.create(optimizer)
+ self._lr_scheduler = OmegaConf.create(lr_scheduler)
+ self.loss_fn = self.configure_criterion(criterion)
# Accuracy metric
self.train_acc = torchmetrics.Accuracy()
@@ -35,27 +34,39 @@ class LitBaseModel(pl.LightningModule):
self.test_acc = torchmetrics.Accuracy()
@staticmethod
- def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
+ def configure_criterion(criterion: Union[OmegaConf, Dict]) -> Type[nn.Module]:
"""Returns a loss functions."""
- args = {} or criterion_args["args"]
- return getattr(nn, criterion_args["type"])(**args)
+ criterion = OmegaConf.create(criterion)
+ args = {} or criterion.args
+ return getattr(nn, criterion.type)(**args)
- def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
- """Configures optimizer and lr scheduler."""
- args = {} or self.optimizer_args["args"]
- if self.optimizer_args["type"] == "MADGRAD":
- optimizer = getattr(madgrad, self.optimizer_args["type"])(**args)
+ def _configure_optimizer(self) -> type:
+ """Configures the optimizer."""
+ args = {} or self._optimizer.args
+ if self._optimizer.type == "MADGRAD":
+ optimizer_class = madgrad.MADGRAD
else:
- optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
+ optimizer_class = getattr(torch.optim, self._optimizer.type)
+ return optimizer_class(parameters=self.parameters(), **args)
+ def _configure_lr_scheduler(self) -> Dict[str, Any]:
+ """Configures the lr scheduler."""
scheduler = {"monitor": self.monitor}
- args = {} or self.lr_scheduler_args["args"]
+ args = {} or self._lr_scheduler.args
+
if "interval" in args:
scheduler["interval"] = args.pop("interval")
scheduler["scheduler"] = getattr(
- torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+ torch.optim.lr_scheduler, self._lr_scheduler.type
)(**args)
+ return scheduler
+
+ def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+ """Configures optimizer and lr scheduler."""
+ optimizer = self._configure_optimizer()
+ scheduler = self._configure_lr_scheduler()
+
return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor:
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 285b715..3625ab2 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,7 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Union, Tuple
+from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch
from torch import nn
@@ -18,15 +19,15 @@ class LitTransformerModel(LitBaseModel):
def __init__(
self,
- network_args: Dict,
- optimizer_args: Dict,
- lr_scheduler_args: Dict,
- criterion_args: Dict,
+ network: Type[nn,Module],
+ optimizer: Union[OmegaConf, Dict],
+ lr_scheduler: Union[OmegaConf, Dict],
+ criterion: Union[OmegaConf, Dict],
monitor: str = "val_loss",
mapping: Optional[List[str]] = None,
) -> None:
super().__init__(
- network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor
+ network, optimizer, lr_scheduler, criterion, monitor
)
self.mapping, ignore_tokens = self.configure_mapping(mapping)
@@ -40,6 +41,7 @@ class LitTransformerModel(LitBaseModel):
@staticmethod
def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
"""Configure mapping."""
+ # TODO: Fix me!!!
mapping, inverse_mapping, _ = emnist_mapping()
start_index = inverse_mapping["<s>"]
end_index = inverse_mapping["<e>"]
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index b9254c9..aa024e0 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -8,10 +8,12 @@ together with the target tokens.
TODO: Local attention for transformer.j
"""
+import importlib
import math
-from typing import Any, Dict, List, Optional, Sequence, Type
+from typing import Dict, List, Union, Sequence, Tuple, Type
from einops import rearrange
+from omegaconf import OmegaConf
import torch
from torch import nn
from torch import Tensor
@@ -32,8 +34,8 @@ class ImageTransformer(nn.Module):
self,
input_shape: Sequence[int],
output_shape: Sequence[int],
- backbone: Type[nn.Module],
- mapping: Optional[List[str]] = None,
+ encoder: Union[OmegaConf, Dict],
+ mapping: str,
num_decoder_layers: int = 4,
hidden_dim: int = 256,
num_heads: int = 4,
@@ -51,8 +53,8 @@ class ImageTransformer(nn.Module):
self.pad_index = inverse_mapping["<p>"]
# Image backbone
- self.backbone = backbone
- self.latent_encoding = PositionalEncoding2D(
+ self.encoder = self._configure_encoder(encoder)
+ self.feature_map_encoding = PositionalEncoding2D(
hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
)
@@ -86,20 +88,30 @@ class ImageTransformer(nn.Module):
self.head.weight.data.uniform_(-0.1, 0.1)
nn.init.kaiming_normal_(
- self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu"
+ self.feature_map_encoding.weight.data,
+ a=0,
+ mode="fan_out",
+ nonlinearity="relu",
)
- if self.latent_encoding.bias is not None:
+ if self.feature_map_encoding.bias is not None:
_, fan_out = nn.init._calculate_fan_in_and_fan_out(
- self.latent_encoding.weight.data
+ self.feature_map_encoding.weight.data
)
bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(self.latent_encoding.bias, -bound, bound)
+ nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+ @staticmethod
+ def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
+ encoder = OmegaConf.create(encoder)
+ network_module = importlib.import_module("text_recognizer.networks")
+ encoder_class = getattr(network_module, encoder.type)
+ return encoder_class(**encoder.args)
def _configure_mapping(
- self, mapping: Optional[List[str]]
+ self, mapping: str
) -> Tuple[List[str], Dict[str, int]]:
"""Configures mapping."""
- if mapping is None:
+ if mapping == "emnist":
mapping, inverse_mapping, _ = emnist_mapping()
return mapping, inverse_mapping
@@ -118,14 +130,14 @@ class ImageTransformer(nn.Module):
"""
# Extract image features.
- latent = self.backbone(image)
+ image_features = self.encoder(image)
# Add 2d encoding to the feature maps.
- latent = self.latent_encoding(latent)
+ image_features = self.feature_map_encoding(image_features)
# Collapse features maps height and width.
- latent = rearrange(latent, "b c h w -> b (h w) c")
- return latent
+ image_features = rearrange(image_features, "b c h w -> b (h w) c")
+ return image_features
def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
"""Decodes image features with transformer decoder."""