summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions/transducer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-25 22:30:22 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-25 22:30:22 +0200
commit00dd3df9f2e29622248668662cb40ff0c8889145 (patch)
treef3e352fa372a8c6bc87455743046b437a54a6dc8 /text_recognizer/criterions/transducer.py
parent8c380f60a4f84f69ab4d2030cce663b4136fa0a7 (diff)
Format transducer
Diffstat (limited to 'text_recognizer/criterions/transducer.py')
-rw-r--r--text_recognizer/criterions/transducer.py118
1 files changed, 42 insertions, 76 deletions
diff --git a/text_recognizer/criterions/transducer.py b/text_recognizer/criterions/transducer.py
index 3c8d5d0..089bff7 100644
--- a/text_recognizer/criterions/transducer.py
+++ b/text_recognizer/criterions/transducer.py
@@ -6,9 +6,8 @@ Stolen from:
"""
from pathlib import Path
import itertools
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Sequence, Set, Tuple
-from loguru import logger
import gtn
import torch
from torch import nn
@@ -65,17 +64,23 @@ def make_transitions_graph(
return transitions
-def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+def make_lexicon_graph(
+ word_pieces: List, graphemes_to_idx: Dict, special_tokens: Optional[Set]
+) -> gtn.Graph:
"""Constructs a graph which transduces letters to word pieces."""
graph = gtn.Graph(False)
graph.add_node(True, True)
for i, wp in enumerate(word_pieces):
prev = 0
- for l in wp[:-1]:
+ if special_tokens is not None and wp in special_tokens:
n = graph.add_node()
- graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
- prev = n
- graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
+ graph.add_arc(prev, n, graphemes_to_idx[wp], i)
+ else:
+ for character in wp[:-1]:
+ n = graph.add_node()
+ graph.add_arc(prev, n, graphemes_to_idx[character], gtn.epsilon)
+ prev = n
+ graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
graph.arc_sort()
return graph
@@ -254,8 +259,7 @@ TransducerLoss = TransducerLossFunction.apply
class Transducer(nn.Module):
def __init__(
self,
- tokens: List,
- graphemes_to_idx: Dict,
+ preprocessor: Preprocessor,
ngram: int = 0,
transitions: str = None,
blank: str = "none",
@@ -265,12 +269,7 @@ class Transducer(nn.Module):
"""A generic transducer loss function.
Args:
- tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
- representing the output tokens of the model (e.g. letters,
- word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
- could be a list of sub-word tokens.
- graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
- "a", "b", ..) to their corresponding integer index.
+ preprocessor (Preprocessor) : The IAM preprocessor for word pieces.
ngram (int) : Order of the token-level transition model. If `ngram=0`
then no transition model is used.
blank (string) : Specifies the usage of blank token
@@ -287,30 +286,47 @@ class Transducer(nn.Module):
raise ValueError(
"Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
)
- self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
- self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+ self.tokens = make_token_graph(
+ preprocessor.tokens, blank=blank, allow_repeats=allow_repeats
+ )
+ self.lexicon = make_lexicon_graph(
+ preprocessor.tokens,
+ preprocessor.graphemes_to_index,
+ preprocessor.special_tokens,
+ )
self.ngram = ngram
+
+ self.transitions: Optional[gtn.Graph] = None
+ self.transitions_params: Optional[nn.Parameter] = None
+ self._load_transitions(transitions, preprocessor, blank)
+
if ngram > 0 and transitions is not None:
raise ValueError("Only one of ngram and transitions may be specified")
- if ngram > 0:
- transitions = make_transitions_graph(
- ngram, len(tokens) + int(blank != "none"), True
- )
+ self.reduction = reduction
+ def _load_transitions(
+ self, transitions: Optional[str], preprocessor: Preprocessor, blank: str
+ ):
+ """Loads transition graph."""
+ processed_path = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+ if transitions is not None:
+ transitions = gtn.load(str(processed_path / transitions))
+ if self.ngram > 0:
+ self.transitions = make_transitions_graph(
+ self.ngram, len(preprocessor.tokens) + int(blank != "none"), True
+ )
if transitions is not None:
self.transitions = transitions
self.transitions.arc_sort()
self.transitions_params = nn.Parameter(
torch.zeros(self.transitions.num_arcs())
)
- else:
- self.transitions = None
- self.transitions_params = None
- self.reduction = reduction
def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
- TransducerLoss(
+ return TransducerLoss(
inputs,
targets,
self.tokens,
@@ -358,53 +374,3 @@ class Transducer(nn.Module):
gtn.parallel_for(process, range(B))
predictions = [torch.IntTensor(path) for path in paths]
return predictions
-
-
-def load_transducer_loss(
- num_features: int,
- ngram: int,
- tokens: str,
- lexicon: str,
- transitions: str,
- blank: str,
- allow_repeats: bool,
- prepend_wordsep: bool = False,
- use_words: bool = False,
- data_dir: Optional[Union[str, Path]] = None,
- reduction: str = "mean",
-) -> Tuple[Transducer, int]:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if transitions is not None:
- transitions = gtn.load(str(processed_path / transitions))
-
- preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
- )
-
- num_tokens = preprocessor.num_tokens
-
- criterion = Transducer(
- preprocessor.tokens,
- preprocessor.graphemes_to_index,
- ngram=ngram,
- transitions=transitions,
- blank=blank,
- allow_repeats=allow_repeats,
- reduction=reduction,
- )
-
- return criterion, num_tokens + int(blank != "none")