diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-25 22:30:22 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-25 22:30:22 +0200 |
commit | 00dd3df9f2e29622248668662cb40ff0c8889145 (patch) | |
tree | f3e352fa372a8c6bc87455743046b437a54a6dc8 | |
parent | 8c380f60a4f84f69ab4d2030cce663b4136fa0a7 (diff) |
Format transducer
-rw-r--r-- | text_recognizer/criterions/transducer.py | 118 |
1 files changed, 42 insertions, 76 deletions
diff --git a/text_recognizer/criterions/transducer.py b/text_recognizer/criterions/transducer.py index 3c8d5d0..089bff7 100644 --- a/text_recognizer/criterions/transducer.py +++ b/text_recognizer/criterions/transducer.py @@ -6,9 +6,8 @@ Stolen from: """ from pathlib import Path import itertools -from typing import Dict, List, Optional, Union, Tuple +from typing import Dict, List, Optional, Sequence, Set, Tuple -from loguru import logger import gtn import torch from torch import nn @@ -65,17 +64,23 @@ def make_transitions_graph( return transitions -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: +def make_lexicon_graph( + word_pieces: List, graphemes_to_idx: Dict, special_tokens: Optional[Set] +) -> gtn.Graph: """Constructs a graph which transduces letters to word pieces.""" graph = gtn.Graph(False) graph.add_node(True, True) for i, wp in enumerate(word_pieces): prev = 0 - for l in wp[:-1]: + if special_tokens is not None and wp in special_tokens: n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) - prev = n - graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) + graph.add_arc(prev, n, graphemes_to_idx[wp], i) + else: + for character in wp[:-1]: + n = graph.add_node() + graph.add_arc(prev, n, graphemes_to_idx[character], gtn.epsilon) + prev = n + graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) graph.arc_sort() return graph @@ -254,8 +259,7 @@ TransducerLoss = TransducerLossFunction.apply class Transducer(nn.Module): def __init__( self, - tokens: List, - graphemes_to_idx: Dict, + preprocessor: Preprocessor, ngram: int = 0, transitions: str = None, blank: str = "none", @@ -265,12 +269,7 @@ class Transducer(nn.Module): """A generic transducer loss function. Args: - tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) - representing the output tokens of the model (e.g. letters, - word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] - could be a list of sub-word tokens. - graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. - "a", "b", ..) to their corresponding integer index. + preprocessor (Preprocessor) : The IAM preprocessor for word pieces. ngram (int) : Order of the token-level transition model. If `ngram=0` then no transition model is used. blank (string) : Specifies the usage of blank token @@ -287,30 +286,47 @@ class Transducer(nn.Module): raise ValueError( "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" ) - self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) - self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) + self.tokens = make_token_graph( + preprocessor.tokens, blank=blank, allow_repeats=allow_repeats + ) + self.lexicon = make_lexicon_graph( + preprocessor.tokens, + preprocessor.graphemes_to_index, + preprocessor.special_tokens, + ) self.ngram = ngram + + self.transitions: Optional[gtn.Graph] = None + self.transitions_params: Optional[nn.Parameter] = None + self._load_transitions(transitions, preprocessor, blank) + if ngram > 0 and transitions is not None: raise ValueError("Only one of ngram and transitions may be specified") - if ngram > 0: - transitions = make_transitions_graph( - ngram, len(tokens) + int(blank != "none"), True - ) + self.reduction = reduction + def _load_transitions( + self, transitions: Optional[str], preprocessor: Preprocessor, blank: str + ): + """Loads transition graph.""" + processed_path = ( + Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" + ) + if transitions is not None: + transitions = gtn.load(str(processed_path / transitions)) + if self.ngram > 0: + self.transitions = make_transitions_graph( + self.ngram, len(preprocessor.tokens) + int(blank != "none"), True + ) if transitions is not None: self.transitions = transitions self.transitions.arc_sort() self.transitions_params = nn.Parameter( torch.zeros(self.transitions.num_arcs()) ) - else: - self.transitions = None - self.transitions_params = None - self.reduction = reduction def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: - TransducerLoss( + return TransducerLoss( inputs, targets, self.tokens, @@ -358,53 +374,3 @@ class Transducer(nn.Module): gtn.parallel_for(process, range(B)) predictions = [torch.IntTensor(path) for path in paths] return predictions - - -def load_transducer_loss( - num_features: int, - ngram: int, - tokens: str, - lexicon: str, - transitions: str, - blank: str, - allow_repeats: bool, - prepend_wordsep: bool = False, - use_words: bool = False, - data_dir: Optional[Union[str, Path]] = None, - reduction: str = "mean", -) -> Tuple[Transducer, int]: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - if transitions is not None: - transitions = gtn.load(str(processed_path / transitions)) - - preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, - ) - - num_tokens = preprocessor.num_tokens - - criterion = Transducer( - preprocessor.tokens, - preprocessor.graphemes_to_index, - ngram=ngram, - transitions=transitions, - blank=blank, - allow_repeats=allow_repeats, - reduction=reduction, - ) - - return criterion, num_tokens + int(blank != "none") |