From 00dd3df9f2e29622248668662cb40ff0c8889145 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 25 Oct 2021 22:30:22 +0200
Subject: Format transducer

---
 text_recognizer/criterions/transducer.py | 118 +++++++++++--------------------
 1 file changed, 42 insertions(+), 76 deletions(-)

diff --git a/text_recognizer/criterions/transducer.py b/text_recognizer/criterions/transducer.py
index 3c8d5d0..089bff7 100644
--- a/text_recognizer/criterions/transducer.py
+++ b/text_recognizer/criterions/transducer.py
@@ -6,9 +6,8 @@ Stolen from:
 """
 from pathlib import Path
 import itertools
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Sequence, Set, Tuple
 
-from loguru import logger
 import gtn
 import torch
 from torch import nn
@@ -65,17 +64,23 @@ def make_transitions_graph(
     return transitions
 
 
-def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+def make_lexicon_graph(
+    word_pieces: List, graphemes_to_idx: Dict, special_tokens: Optional[Set]
+) -> gtn.Graph:
     """Constructs a graph which transduces letters to word pieces."""
     graph = gtn.Graph(False)
     graph.add_node(True, True)
     for i, wp in enumerate(word_pieces):
         prev = 0
-        for l in wp[:-1]:
+        if special_tokens is not None and wp in special_tokens:
             n = graph.add_node()
-            graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
-            prev = n
-        graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
+            graph.add_arc(prev, n, graphemes_to_idx[wp], i)
+        else:
+            for character in wp[:-1]:
+                n = graph.add_node()
+                graph.add_arc(prev, n, graphemes_to_idx[character], gtn.epsilon)
+                prev = n
+            graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
     graph.arc_sort()
     return graph
 
@@ -254,8 +259,7 @@ TransducerLoss = TransducerLossFunction.apply
 class Transducer(nn.Module):
     def __init__(
         self,
-        tokens: List,
-        graphemes_to_idx: Dict,
+        preprocessor: Preprocessor,
         ngram: int = 0,
         transitions: str = None,
         blank: str = "none",
@@ -265,12 +269,7 @@ class Transducer(nn.Module):
         """A generic transducer loss function.
 
         Args:
-            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
-                representing the output tokens of the model (e.g. letters,
-                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
-                could be a list of sub-word tokens.
-            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
-                "a", "b", ..) to their corresponding integer index.
+            preprocessor (Preprocessor) : The IAM preprocessor for word pieces.
             ngram (int) : Order of the token-level transition model. If `ngram=0`
                 then no transition model is used.
             blank (string) : Specifies the usage of blank token
@@ -287,30 +286,47 @@ class Transducer(nn.Module):
             raise ValueError(
                 "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
             )
-        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
-        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+        self.tokens = make_token_graph(
+            preprocessor.tokens, blank=blank, allow_repeats=allow_repeats
+        )
+        self.lexicon = make_lexicon_graph(
+            preprocessor.tokens,
+            preprocessor.graphemes_to_index,
+            preprocessor.special_tokens,
+        )
         self.ngram = ngram
+
+        self.transitions: Optional[gtn.Graph] = None
+        self.transitions_params: Optional[nn.Parameter] = None
+        self._load_transitions(transitions, preprocessor, blank)
+
         if ngram > 0 and transitions is not None:
             raise ValueError("Only one of ngram and transitions may be specified")
 
-        if ngram > 0:
-            transitions = make_transitions_graph(
-                ngram, len(tokens) + int(blank != "none"), True
-            )
+        self.reduction = reduction
 
+    def _load_transitions(
+        self, transitions: Optional[str], preprocessor: Preprocessor, blank: str
+    ):
+        """Loads transition graph."""
+        processed_path = (
+            Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+        )
+        if transitions is not None:
+            transitions = gtn.load(str(processed_path / transitions))
+        if self.ngram > 0:
+            self.transitions = make_transitions_graph(
+                self.ngram, len(preprocessor.tokens) + int(blank != "none"), True
+            )
         if transitions is not None:
             self.transitions = transitions
             self.transitions.arc_sort()
             self.transitions_params = nn.Parameter(
                 torch.zeros(self.transitions.num_arcs())
             )
-        else:
-            self.transitions = None
-            self.transitions_params = None
-        self.reduction = reduction
 
     def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
-        TransducerLoss(
+        return TransducerLoss(
             inputs,
             targets,
             self.tokens,
@@ -358,53 +374,3 @@ class Transducer(nn.Module):
         gtn.parallel_for(process, range(B))
         predictions = [torch.IntTensor(path) for path in paths]
         return predictions
-
-
-def load_transducer_loss(
-    num_features: int,
-    ngram: int,
-    tokens: str,
-    lexicon: str,
-    transitions: str,
-    blank: str,
-    allow_repeats: bool,
-    prepend_wordsep: bool = False,
-    use_words: bool = False,
-    data_dir: Optional[Union[str, Path]] = None,
-    reduction: str = "mean",
-) -> Tuple[Transducer, int]:
-    if data_dir is None:
-        data_dir = (
-            Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
-        )
-        logger.debug(f"Using data dir: {data_dir}")
-        if not data_dir.exists():
-            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
-    else:
-        data_dir = Path(data_dir)
-    processed_path = (
-        Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
-    )
-    tokens_path = processed_path / tokens
-    lexicon_path = processed_path / lexicon
-
-    if transitions is not None:
-        transitions = gtn.load(str(processed_path / transitions))
-
-    preprocessor = Preprocessor(
-        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
-    )
-
-    num_tokens = preprocessor.num_tokens
-
-    criterion = Transducer(
-        preprocessor.tokens,
-        preprocessor.graphemes_to_index,
-        ngram=ngram,
-        transitions=transitions,
-        blank=blank,
-        allow_repeats=allow_repeats,
-        reduction=reduction,
-    )
-
-    return criterion, num_tokens + int(blank != "none")
-- 
cgit v1.2.3-70-g09d2