path: root/src/text_recognizer/models
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--  src/text_recognizer/models/__init__.py                  |  14
-rw-r--r--  src/text_recognizer/models/base.py                      |   8
-rw-r--r--  src/text_recognizer/models/metrics.py                   |  21
-rw-r--r--  src/text_recognizer/models/transformer_encoder_model.py | 111
-rw-r--r--  src/text_recognizer/models/transformer_model.py (renamed from src/text_recognizer/models/vision_transformer_model.py) | 13
5 files changed, 33 insertions, 134 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index 28aa52e..53340f1 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,18 +2,16 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
-from .metrics import accuracy, cer, wer
-from .transformer_encoder_model import TransformerEncoderModel
-from .vision_transformer_model import VisionTransformerModel
+from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .transformer_model import TransformerModel
__all__ = [
- "Model",
+ "accuracy",
+ "accuracy_ignore_pad",
"cer",
"CharacterModel",
"CRNNModel",
- "CNNTransfromerModel",
- "accuracy",
- "TransformerEncoderModel",
- "VisionTransformerModel",
+ "Model",
+ "TransformerModel",
"wer",
]
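
With this change callers import the renamed model and the new padded-accuracy metric from the package root, for example:

    from text_recognizer.models import TransformerModel, accuracy_ignore_pad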
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cc44c92..a945b41 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class Model(ABC):
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
- criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
@@ -221,7 +221,7 @@ class Model(ABC):
def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
- # If no network arguemnts are given, load pretrained weights if they exist.
+ # If no network arguments are given, load pretrained weights if they exist.
if self._network_args is None:
self.load_weights(network_fn)
else:
@@ -245,7 +245,7 @@ class Model(ABC):
self._optimizer = None
if self._optimizer and self._lr_scheduler is not None:
- if "OneCycleLR" in str(self._lr_scheduler):
+ if "steps_per_epoch" in self.lr_scheduler_args:
self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
# Assume lr scheduler should update at each epoch if not specified.
@@ -412,7 +412,7 @@ class Model(ABC):
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
if self._lr_scheduler is not None:
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
self._lr_scheduler["lr_scheduler"].load_state_dict(
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 42c3c6e..af9adb5 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -6,7 +6,23 @@ from torch import Tensor
from text_recognizer.networks import greedy_decoder
-def accuracy(outputs: Tensor, labels: Tensor) -> float:
+def accuracy_ignore_pad(
+ output: Tensor,
+ target: Tensor,
+ pad_index: int = 79,
+ eos_index: int = 81,
+ seq_len: int = 97,
+) -> float:
+ """Sets all predictions after eos to pad."""
+ start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
+ end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
+ for start, stop in zip(start_indices, end_indices):
+ output[start + 1 : stop] = pad_index
+
+ return accuracy(output, target)
+
+
+def accuracy(outputs: Tensor, labels: Tensor,) -> float:
"""Computes the accuracy.
Args:
@@ -17,10 +33,9 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- # eos_index = torch.nonzero(labels == eos, as_tuple=False)
- # eos_index = eos_index[0].item() if eos_index.nelement() else -1
_, predicted = torch.max(outputs, dim=-1)
+
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
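
The new accuracy_ignore_pad assumes the targets are flattened to shape (batch * seq_len,) and forces every position after a sequence's EOS token to the pad index before delegating to accuracy, so trailing predictions after EOS do not count against the score. A small self-contained illustration of that masking step, operating directly on predicted class indices rather than raw network output (the toy tensors are made up; the 79/81 indices match the defaults in the signature above):

    import torch

    pad_index, eos_index, seq_len = 79, 81, 5

    # Two flattened target sequences of length 5, padded after EOS.
    target = torch.tensor([10, 11, 81, 79, 79, 12, 81, 79, 79, 79])
    predicted = torch.tensor([10, 11, 81, 3, 7, 12, 81, 4, 79, 79])

    # Force everything after each EOS to pad, mirroring accuracy_ignore_pad.
    eos_positions = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
    sequence_ends = torch.arange(seq_len, target.shape[0] + 1, seq_len)
    for start, stop in zip(eos_positions, sequence_ends):
        predicted[start + 1 : stop] = pad_index

    accuracy = (predicted == target).float().mean().item()
    print(accuracy)  # 1.0 once the post-EOS predictions are masked out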
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
deleted file mode 100644
index e35e298..0000000
--- a/src/text_recognizer/models/transformer_encoder_model.py
+++ /dev/null
@@ -1,111 +0,0 @@
-"""Defines the CNN-Transformer class."""
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-from torch.utils.data import Dataset
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets import EmnistMapper
-from text_recognizer.models.base import Model
-
-
-class TransformerEncoderModel(Model):
- """A class for only using the encoder part in the sequence modelling."""
-
- def __init__(
- self,
- network_fn: Type[nn.Module],
- dataset: Type[Dataset],
- network_args: Optional[Dict] = None,
- dataset_args: Optional[Dict] = None,
- metrics: Optional[Dict] = None,
- criterion: Optional[Callable] = None,
- criterion_args: Optional[Dict] = None,
- optimizer: Optional[Callable] = None,
- optimizer_args: Optional[Dict] = None,
- lr_scheduler: Optional[Callable] = None,
- lr_scheduler_args: Optional[Dict] = None,
- swa_args: Optional[Dict] = None,
- device: Optional[str] = None,
- ) -> None:
- super().__init__(
- network_fn,
- dataset,
- network_args,
- dataset_args,
- metrics,
- criterion,
- criterion_args,
- optimizer,
- optimizer_args,
- lr_scheduler,
- lr_scheduler_args,
- swa_args,
- device,
- )
- # self.init_token = dataset_args["args"]["init_token"]
- self.pad_token = dataset_args["args"]["pad_token"]
- self.eos_token = dataset_args["args"]["eos_token"]
- if network_args is not None:
- self.max_len = network_args["max_len"]
- else:
- self.max_len = 128
-
- if self._mapper is None:
- self._mapper = EmnistMapper(
- # init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
- self.tensor_transform = ToTensor()
-
- self.softmax = nn.Softmax(dim=2)
-
- @torch.no_grad()
- def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
- logits = self.network(image)
- # Convert logits to probabilities.
- probs = self.softmax(logits).squeeze(0)
-
- confidence, pred_tokens = probs.max(1)
- pred_tokens = pred_tokens
-
- eos_index = torch.nonzero(
- pred_tokens == self._mapper(self.eos_token), as_tuple=False,
- )
-
- eos_index = eos_index[0].item() if eos_index.nelement() else -1
-
- predicted_characters = "".join(
- [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
- )
-
- confidence = np.min(confidence.tolist())
-
- return predicted_characters, confidence
-
- @torch.no_grad()
- def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
- """Predict on a single input."""
- self.eval()
-
- if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
- image = self.tensor_transform(image)
-
- # Rescale image between 0 and 1.
- if image.dtype == torch.uint8:
- # If the image is an unscaled tensor.
- image = image.type("torch.FloatTensor") / 255
-
- # Put the image tensor on the device the model weights are on.
- image = image.to(self.device)
-
- (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
- image
- )
-
- return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/transformer_model.py
index 3d36437..968a047 100644
--- a/src/text_recognizer/models/vision_transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -13,7 +13,7 @@ from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
-class VisionTransformerModel(Model):
+class TransformerModel(Model):
"""Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
def __init__(
@@ -50,10 +50,7 @@ class VisionTransformerModel(Model):
self.init_token = dataset_args["args"]["init_token"]
self.pad_token = dataset_args["args"]["pad_token"]
self.eos_token = dataset_args["args"]["eos_token"]
- if network_args is not None:
- self.max_len = network_args["max_len"]
- else:
- self.max_len = 120
+ self.max_len = 120
if self._mapper is None:
self._mapper = EmnistMapper(
@@ -67,7 +64,7 @@ class VisionTransformerModel(Model):
@torch.no_grad()
def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
- src = self.network.preprocess_input(image)
+ src = self.network.extract_image_features(image)
memory = self.network.encoder(src)
confidence_of_predictions = []
@@ -75,7 +72,7 @@ class VisionTransformerModel(Model):
for _ in range(self.max_len - 1):
trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg = self.network.preprocess_target(trg)
+ trg = self.network.target_embedding(trg)
logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
# Convert logits to probabilities.
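
The renamed hooks (extract_image_features and target_embedding) are called inside _generate_sentence's greedy decoding loop: encode the image once, then feed the growing target sequence back through the decoder one token at a time until EOS or max_len. A rough, self-contained sketch of that loop with a stand-in network exposing the same method names (the toy module, vocabulary size, and token indices are illustrative, not the repo's actual TransformerModel):

    import torch
    from torch import nn


    class ToyTransformer(nn.Module):
        """Stand-in network with the renamed hooks; shapes and vocab are made up."""

        def __init__(self, vocab_size: int = 83, hidden_dim: int = 32) -> None:
            super().__init__()
            self.backbone = nn.Linear(28 * 28, hidden_dim)
            self.embedding = nn.Embedding(vocab_size, hidden_dim)
            self.encoder = nn.Identity()
            self.head = nn.Linear(hidden_dim, vocab_size)

        def extract_image_features(self, image: torch.Tensor) -> torch.Tensor:
            # A CNN backbone in the real network; flatten + linear stands in here.
            return self.backbone(image.flatten(1)).unsqueeze(1)

        def target_embedding(self, trg: torch.Tensor) -> torch.Tensor:
            return self.embedding(trg)

        def decoder(self, trg: torch.Tensor, memory: torch.Tensor, trg_mask=None) -> torch.Tensor:
            # The real decoder attends over memory; a broadcast addition stands in for it.
            return self.head(trg + memory)


    network = ToyTransformer()
    init_index, eos_index, max_len = 80, 81, 10

    image = torch.rand(1, 28 * 28)
    src = network.extract_image_features(image)
    memory = network.encoder(src)

    # Greedy decoding: append the most likely next token until EOS or max_len.
    trg_indices = [init_index]
    for _ in range(max_len - 1):
        trg = torch.tensor(trg_indices)[None, :].long()
        trg = network.target_embedding(trg)
        logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
        next_token = int(logits[0, -1].softmax(dim=-1).argmax())
        trg_indices.append(next_token)
        if next_token == eos_index:
            break

    print(trg_indices)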
@@ -101,7 +98,7 @@ class VisionTransformerModel(Model):
self.eval()
if image.dtype == np.uint8:
- # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
image = self.tensor_transform(image)
# Rescale image between 0 and 1.