path: root/text_recognizer/models/transformer.py
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-08-03 18:18:48 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-08-03 18:18:48 +0200
commit    bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
tree      e55cb3744904f7c2a0348b100c7e92a65e538a16 /text_recognizer/models/transformer.py
parent    75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--  text_recognizer/models/transformer.py | 36 ++++++++++++++++++------------------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 91e088d..5fb84a7 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -5,7 +5,6 @@ import attr
 import torch
 from torch import Tensor
-from text_recognizer.data.mappings import AbstractMapping
 from text_recognizer.models.metrics import CharacterErrorRate
 from text_recognizer.models.base import BaseLitModel
@@ -14,14 +13,14 @@ from text_recognizer.models.base import BaseLitModel
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
-    mapping: Type[AbstractMapping] = attr.ib(default=None)
+    max_output_len: int = attr.ib(default=451)
     start_token: str = attr.ib(default="<s>")
     end_token: str = attr.ib(default="<e>")
     pad_token: str = attr.ib(default="<p>")
-    start_index: Tensor = attr.ib(init=False)
-    end_index: Tensor = attr.ib(init=False)
-    pad_index: Tensor = attr.ib(init=False)
+    start_index: int = attr.ib(init=False)
+    end_index: int = attr.ib(init=False)
+    pad_index: int = attr.ib(init=False)
     ignore_indices: Set[Tensor] = attr.ib(init=False)
     val_cer: CharacterErrorRate = attr.ib(init=False)
@@ -29,9 +28,9 @@ class TransformerLitModel(BaseLitModel):
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = self.mapping.get_index(self.start_token)
-        self.end_index = self.mapping.get_index(self.end_token)
-        self.pad_index = self.mapping.get_index(self.pad_token)
+        self.start_index = int(self.mapping.get_index(self.start_token))
+        self.end_index = int(self.mapping.get_index(self.end_token))
+        self.pad_index = int(self.mapping.get_index(self.pad_token))
         self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
         self.val_cer = CharacterErrorRate(self.ignore_indices)
         self.test_cer = CharacterErrorRate(self.ignore_indices)
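
A note on the int() casts introduced above: torch.Tensor hashes by object identity rather than by value, so building ignore_indices from 0-dim index tensors would make value-based membership tests inside CharacterErrorRate silently fail. A minimal standalone sketch of the pitfall (the value 3 is illustrative, not one of the repo's actual token indices):

import torch

# Two 0-dim tensors holding the same value but with distinct identities.
t1 = torch.tensor(3)
t2 = torch.tensor(3)

print(t1 in {t1})            # True: identity short-circuits the membership test
print(t2 in {t1})            # False: tensors hash by id, not by value
print(int(t2) in {int(t1)})  # True: plain ints hash by value

Casting each index to int in __attrs_post_init__ makes ignore_indices behave like an ordinary set of token ids.
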
@@ -93,23 +92,24 @@ class TransformerLitModel(BaseLitModel):
         output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
         output[:, 0] = self.start_index
-        for i in range(1, self.max_output_len):
-            context = output[:, :i]  # (bsz, i)
-            logits = self.network.decode(z, context)  # (i, bsz, c)
-            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
-            output[:, i : i + 1] = tokens[-1:]
+        for Sy in range(1, self.max_output_len):
+            context = output[:, :Sy]  # (B, Sy)
+            logits = self.network.decode(z, context)  # (B, Sy, C)
+            tokens = torch.argmax(logits, dim=-1)  # (B, Sy)
+            output[:, Sy : Sy + 1] = tokens[:, -1:]
             # Early stopping of prediction loop if token is end or padding token.
             if (
-                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+                (output[:, Sy - 1] == self.end_index)
+                | (output[:, Sy - 1] == self.pad_index)
             ).all():
                 break
         # Set all tokens after end token to pad token.
-        for i in range(1, self.max_output_len):
-            idx = (
-                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+        for Sy in range(1, self.max_output_len):
+            idx = (output[:, Sy - 1] == self.end_index) | (
+                output[:, Sy - 1] == self.pad_index
             )
-            output[idx, i] = self.pad_index
+            output[idx, Sy] = self.pad_index
         return output
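
For reference, a self-contained sketch of the fixed greedy decoding loop, runnable on its own. The decode function below is a dummy stand-in for self.network.decode, and the token ids, batch size, and maximum length are made-up values, not the repo's actual configuration:

import torch
from torch import Tensor

START, END, PAD = 1, 2, 0  # illustrative token ids
MAX_OUTPUT_LEN = 8         # stand-in for max_output_len

def decode(z: Tensor, context: Tensor) -> Tensor:
    """Dummy decoder: returns logits of shape (B, Sy, C)."""
    B, Sy = context.shape
    logits = torch.zeros(B, Sy, 4)
    logits[:, :, END] = 1.0  # this dummy always predicts the end token
    return logits

z = torch.zeros(2, 3)  # dummy encoder output for a batch of 2
B = z.shape[0]
output = torch.ones((B, MAX_OUTPUT_LEN), dtype=torch.long)
output[:, 0] = START
for Sy in range(1, MAX_OUTPUT_LEN):
    context = output[:, :Sy]                 # (B, Sy)
    logits = decode(z, context)              # (B, Sy, C)
    tokens = torch.argmax(logits, dim=-1)    # (B, Sy)
    output[:, Sy : Sy + 1] = tokens[:, -1:]  # keep only the newest prediction
    # Early stopping once every sequence has produced end or padding.
    if ((output[:, Sy - 1] == END) | (output[:, Sy - 1] == PAD)).all():
        break
# Set all tokens after the end token to the pad token.
for Sy in range(1, MAX_OUTPUT_LEN):
    idx = (output[:, Sy - 1] == END) | (output[:, Sy - 1] == PAD)
    output[idx, Sy] = PAD
print(output)  # e.g. tensor([[1, 2, 0, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0]])

Two things the commit changes are visible here: the slice tokens[:, -1:] takes the last time step along dim 1 (the old tokens[-1:] sliced the batch dimension under the new (B, Sy, C) layout), and the end/pad comparisons are parenthesized because | binds tighter than == in Python, so the old unparenthesized condition evaluated self.end_index | output[: i - 1] first.
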