summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models')
-rw-r--r--src/text_recognizer/models/__init__.py7
-rw-r--r--src/text_recognizer/models/base.py9
-rw-r--r--src/text_recognizer/models/character_model.py3
-rw-r--r--src/text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/line_ctc_model.py)10
-rw-r--r--src/text_recognizer/models/metrics.py5
-rw-r--r--src/text_recognizer/models/transformer_encoder_model.py111
-rw-r--r--src/text_recognizer/models/vision_transformer_model.py12
7 files changed, 140 insertions, 17 deletions
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index 0855079..28aa52e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -1,16 +1,19 @@
"""Model modules."""
from .base import Model
from .character_model import CharacterModel
-from .line_ctc_model import LineCTCModel
+from .crnn_model import CRNNModel
from .metrics import accuracy, cer, wer
+from .transformer_encoder_model import TransformerEncoderModel
from .vision_transformer_model import VisionTransformerModel
__all__ = [
"Model",
"cer",
"CharacterModel",
+ "CRNNModel",
"CNNTransfromerModel",
- "LineCTCModel",
"accuracy",
+ "TransformerEncoderModel",
+ "VisionTransformerModel",
"wer",
]
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cbef787..cc44c92 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -141,11 +141,12 @@ class Model(ABC):
"transform" in self.dataset_args["args"]
and self.dataset_args["args"]["transform"] is not None
):
- transform_ = [
- getattr(transforms_module, t["type"])()
- for t in self.dataset_args["args"]["transform"]
- ]
+ transform_ = []
+ for t in self.dataset_args["args"]["transform"]:
+ args = t["args"] or {}
+ transform_.append(getattr(transforms_module, t["type"])(**args))
self.dataset_args["args"]["transform"] = Compose(transform_)
+
if (
"target_transform" in self.dataset_args["args"]
and self.dataset_args["args"]["target_transform"] is not None
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 3cf6695..f9944f3 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -47,8 +47,9 @@ class CharacterModel(Model):
swa_args,
device,
)
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
self.tensor_transform = ToTensor()
self.softmax = nn.Softmax(dim=0)
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/crnn_model.py
index cdc2d8b..1e01a83 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/crnn_model.py
@@ -1,4 +1,4 @@
-"""Defines the LineCTCModel class."""
+"""Defines the CRNNModel class."""
from typing import Callable, Dict, Optional, Tuple, Type, Union
import numpy as np
@@ -13,7 +13,7 @@ from text_recognizer.models.base import Model
from text_recognizer.networks import greedy_decoder
-class LineCTCModel(Model):
+class CRNNModel(Model):
"""Model for predicting a sequence of characters from an image of a text line."""
def __init__(
@@ -47,8 +47,10 @@ class LineCTCModel(Model):
swa_args,
device,
)
+
+ self.pad_token = dataset_args["args"]["pad_token"]
if self._mapper is None:
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
self.tensor_transform = ToTensor()
def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
@@ -112,6 +114,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(-log_probs.sum()).item()
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py
index 6a26216..42c3c6e 100644
--- a/src/text_recognizer/models/metrics.py
+++ b/src/text_recognizer/models/metrics.py
@@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float:
float: The accuracy for the batch.
"""
- _, predicted = torch.max(outputs.data, dim=1)
+ # eos_index = torch.nonzero(labels == eos, as_tuple=False)
+ # eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ _, predicted = torch.max(outputs, dim=-1)
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py
new file mode 100644
index 0000000..e35e298
--- /dev/null
+++ b/src/text_recognizer/models/transformer_encoder_model.py
@@ -0,0 +1,111 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class TransformerEncoderModel(Model):
+ """A class for only using the encoder part in the sequence modelling."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ # self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ if network_args is not None:
+ self.max_len = network_args["max_len"]
+ else:
+ self.max_len = 128
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ # init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ self.tensor_transform = ToTensor()
+
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ logits = self.network(image)
+ # Convert logits to probabilities.
+ probs = self.softmax(logits).squeeze(0)
+
+ confidence, pred_tokens = probs.max(1)
+ pred_tokens = pred_tokens
+
+ eos_index = torch.nonzero(
+ pred_tokens == self._mapper(self.eos_token), as_tuple=False,
+ )
+
+ eos_index = eos_index[0].item() if eos_index.nelement() else -1
+
+ predicted_characters = "".join(
+ [self.mapper(x) for x in pred_tokens[:eos_index].tolist()]
+ )
+
+ confidence = np.min(confidence.tolist())
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
+
+ return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py
index 20bd4ca..3d36437 100644
--- a/src/text_recognizer/models/vision_transformer_model.py
+++ b/src/text_recognizer/models/vision_transformer_model.py
@@ -53,7 +53,7 @@ class VisionTransformerModel(Model):
if network_args is not None:
self.max_len = network_args["max_len"]
else:
- self.max_len = 128
+ self.max_len = 120
if self._mapper is None:
self._mapper = EmnistMapper(
@@ -73,10 +73,10 @@ class VisionTransformerModel(Model):
confidence_of_predictions = []
trg_indices = [self.mapper(self.init_token)]
- for _ in range(self.max_len):
+ for _ in range(self.max_len - 1):
trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
- trg, trg_mask = self.network.preprocess_target(trg)
- logits = self.network.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ trg = self.network.preprocess_target(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
# Convert logits to probabilities.
probs = self.softmax(logits)
@@ -112,6 +112,8 @@ class VisionTransformerModel(Model):
# Put the image tensor on the device the model weights are on.
image = image.to(self.device)
- predicted_characters, confidence_of_prediction = self._generate_sentence(image)
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
return predicted_characters, confidence_of_prediction