diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /tasks | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'tasks')
-rw-r--r-- | tasks/build_transitions.py | 263 | ||||
-rwxr-xr-x | tasks/create_emnist_lines_datasets.sh | 4 | ||||
-rwxr-xr-x | tasks/create_iam_paragraphs.sh | 2 | ||||
-rwxr-xr-x | tasks/download_emnist.sh | 3 | ||||
-rwxr-xr-x | tasks/download_iam.sh | 2 | ||||
-rw-r--r-- | tasks/make_wordpieces.py | 114 | ||||
-rwxr-xr-x | tasks/prepare_experiments.sh | 3 | ||||
-rwxr-xr-x | tasks/test_functionality.sh | 2 | ||||
-rwxr-xr-x | tasks/train.sh | 68 |
9 files changed, 461 insertions, 0 deletions
diff --git a/tasks/build_transitions.py b/tasks/build_transitions.py new file mode 100644 index 0000000..91f8c1a --- /dev/null +++ b/tasks/build_transitions.py @@ -0,0 +1,263 @@ +"""Builds transition graph. + +Most code stolen from here: + + https://github.com/facebookresearch/gtn_applications/blob/master/scripts/build_transitions.py + +""" + +import collections +import itertools +from pathlib import Path +from typing import Any, Dict, List, Optional + +import click +import gtn +from loguru import logger + + +START_IDX = -1 +END_IDX = -2 +WORDSEP = "▁" + + +def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph: + """Returns a gtn Graph based on the ngrams.""" + graph = gtn.Graph(False) + ngram = len(ngrams) + state_to_node = {} + + def get_node(state: Optional[List]) -> Any: + node = state_to_node.get(state, None) + + if node is not None: + return node + + start = state == tuple([START_IDX]) if ngram > 1 else True + end = state == tuple([END_IDX]) if ngram > 1 else True + node = graph.add_node(start, end) + state_to_node[state] = node + + if not disable_backoff and not end: + # Add back off when adding node. + for n in range(1, len(state) + 1): + backoff_node = state_to_node.get(state[n:], None) + + # Epsilon transition to the back-off state. + if backoff_node is not None: + graph.add_arc(node, backoff_node, gtn.epsilon) + break + return node + + for grams in ngrams: + for gram in grams: + istate, ostate = gram[:-1], gram[len(gram) - ngram + 1 :] + inode = get_node(istate) + + if END_IDX not in gram[1:] and gram[1:] not in state_to_node: + raise ValueError( + "Ill formed counts: if (x, y_1, ..., y_{n-1}) is above" + "the n-gram threshold, then (y_1, ..., y_{n-1}) must be" + "above the (n-1)-gram threshold" + ) + + if END_IDX in ostate: + # Merge all state having </s> into one as final graph generated + # will be similar. + ostate = tuple([END_IDX]) + + onode = get_node(ostate) + # p(gram[-1] | gram[:-1]) + graph.add_arc( + inode, onode, gtn.epsilon if gram[-1] == END_IDX else gram[-1] + ) + return graph + + +def count_ngrams(lines: List, ngram: List, tokens_to_index: Dict) -> List: + """Counts the number of ngrams.""" + counts = [collections.Counter() for _ in range(ngram)] + for line in lines: + # Prepend implicit start token. + token_line = [START_IDX] + for t in line: + token_line.append(tokens_to_index[t]) + token_line.append(END_IDX) + for n, counter in enumerate(counts): + start_offset = n == 0 + end_offset = ngram == 1 + for e in range(n + start_offset, len(token_line) - end_offset): + counter[tuple(token_line[e - n : e + 1])] += 1 + + return counts + + +def prune_ngrams(ngrams: List, prune: List) -> List: + """Prunes ngrams.""" + pruned_ngrams = [] + for n, grams in enumerate(ngrams): + grams = grams.most_common() + pruned_grams = [gram for gram, c in grams if c > prune[n]] + pruned_ngrams.append(pruned_grams) + return pruned_ngrams + + +def add_blank_grams(pruned_ngrams: List, num_tokens: int, blank: str) -> List: + """Adds blank token to grams.""" + all_grams = [gram for grams in pruned_ngrams for gram in grams] + maxorder = len(pruned_ngrams) + blank_grams = {} + if blank == "forced": + pruned_ngrams = [pruned_ngrams[0] if i == 0 else [] for i in range(maxorder)] + pruned_ngrams[0].append(tuple([num_tokens])) + blank_grams[tuple([num_tokens])] = True + + for gram in all_grams: + # Iterate over all possibilities by using a vector of 0s, 1s to + # denote whether a blank is being used at each position. + if blank == "optional": + # Given a gram ab.. if order n, we have n + 1 positions + # available whether to use blank or not. + onehot_vectors = itertools.product([0, 1], repeat=len(gram) + 1) + elif blank == "forced": + # Must include a blank token in between. + onehot_vectors = [[1] * (len(gram) + 1)] + else: + raise ValueError( + "Invalid value specificed for blank. Must be in |optional|forced|none|" + ) + + for j in onehot_vectors: + new_array = [] + for idx, oz in enumerate(j[:-1]): + if oz == 1 and gram[idx] != START_IDX: + new_array.append(num_tokens) + new_array.append(gram[idx]) + if j[-1] == 1 and gram[-1] != END_IDX: + new_array.append(num_tokens) + for n in range(maxorder): + for e in range(n, len(new_array)): + cur_gram = tuple(new_array[e - n : e + 1]) + if num_tokens in cur_gram and cur_gram not in blank_grams: + pruned_ngrams[n].append(cur_gram) + blank_grams[cur_gram] = True + + return pruned_ngrams + + +def add_self_loops(pruned_ngrams: List) -> List: + """Adds self loops to the ngrams.""" + maxorder = len(pruned_ngrams) + + # Use dict for fast search. + all_grams = set([gram for grams in pruned_ngrams for gram in grams]) + for o in range(1, maxorder): + for gram in pruned_ngrams[o - 1]: + # Repeat one of the tokens. + for pos in range(len(gram)): + if gram[pos] == START_IDX or gram[pos] == END_IDX: + continue + new_gram = gram[:pos] + (gram[pos],) + gram[pos:] + + if new_gram not in all_grams: + pruned_ngrams[o].append(new_gram) + all_grams.add(new_gram) + return pruned_ngrams + + +def parse_lines(lines: List, lexicon: Path) -> List: + """Parses lines with a lexicon.""" + with open(lexicon, "r") as f: + lex = (line.strip().split() for line in f) + lex = {line[0]: line[1:] for line in lex} + print(len(lex)) + return [[t for w in line.split(WORDSEP) for t in lex[w]] for line in lines] + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to dataset root.") +@click.option( + "--tokens", type=str, help="Path to token list (in order used with training)." +) +@click.option("--lexicon", type=str, default=None, help="Path to lexicon") +@click.option( + "--prune", + nargs=2, + type=int, + help="Threshold values for prune unigrams, bigrams, etc.", +) +@click.option( + "--blank", + default=click.Choice(["none", "optional", "forced"]), + help="Specifies the usage of blank token" + "'none' - do not use blank token " + "'optional' - allow an optional blank inbetween tokens" + "'forced' - force a blank inbetween tokens (also referred to as garbage token)", +) +@click.option("--self_loops", is_flag=True, help="Add self loops for tokens") +@click.option("--disable_backoff", is_flag=True, help="Disable backoff transitions") +@click.option("--save_path", default=None, help="Path to save transition graph.") +def cli( + data_dir: str, + tokens: str, + lexicon: str, + prune: List[int], + blank: str, + self_loops: bool, + disable_backoff: bool, + save_path: str, +) -> None: + """CLI for creating the transitions.""" + logger.info(f"Building {len(prune)}-gram transition models.") + + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + # Build table of counts and the back-off if below threshold. + with open(data_dir / "train.txt", "r") as f: + lines = [line.strip() for line in f] + + with open(data_dir / tokens, "r") as f: + tokens = [line.strip() for line in f] + + if lexicon is not None: + lexicon = data_dir / lexicon + lines = parse_lines(lines, lexicon) + + tokens_to_idx = {t: e for e, t in enumerate(tokens)} + + ngram = len(prune) + + logger.info("Counting data...") + ngrams = count_ngrams(lines, ngram, tokens_to_idx) + + pruned_ngrams = prune_ngrams(ngrams, prune) + + for n in range(ngram): + logger.info(f"Kept {len(pruned_ngrams[n])} of {len(ngrams[n])} {n + 1}-grams") + + if blank == "none": + pruned_ngrams = add_blank_grams(pruned_ngrams, len(tokens_to_idx), blank) + + if self_loops: + pruned_ngrams = add_self_loops(pruned_ngrams) + + logger.info("Building graph from pruned ngrams...") + graph = build_graph(pruned_ngrams, disable_backoff) + logger.info(f"Graph has {graph.num_arcs()} arcs and {graph.num_nodes()} nodes.") + + save_path = str(data_dir / save_path) + + logger.info(f"Saving graph to {save_path}") + gtn.save(save_path, graph) + + +if __name__ == "__main__": + cli() diff --git a/tasks/create_emnist_lines_datasets.sh b/tasks/create_emnist_lines_datasets.sh new file mode 100755 index 0000000..6416277 --- /dev/null +++ b/tasks/create_emnist_lines_datasets.sh @@ -0,0 +1,4 @@ +#!/usr/bin/fish +command="python text_recognizer/datasets/emnist_lines_dataset.py --max_length 34 --min_overlap 0.0 --max_overlap 0.33 --num_train 100000 --num_test 10000" +echo $command +eval $command diff --git a/tasks/create_iam_paragraphs.sh b/tasks/create_iam_paragraphs.sh new file mode 100755 index 0000000..fa2bfb0 --- /dev/null +++ b/tasks/create_iam_paragraphs.sh @@ -0,0 +1,2 @@ +#!/usr/bin/fish +poetry run create-iam-paragraphs diff --git a/tasks/download_emnist.sh b/tasks/download_emnist.sh new file mode 100755 index 0000000..18c8e29 --- /dev/null +++ b/tasks/download_emnist.sh @@ -0,0 +1,3 @@ +#!/usr/bin/fish +poetry run download-emnist +poetry run create-emnist-support-files diff --git a/tasks/download_iam.sh b/tasks/download_iam.sh new file mode 100755 index 0000000..e3cf76b --- /dev/null +++ b/tasks/download_iam.sh @@ -0,0 +1,2 @@ +#!/usr/bin/fish +poetry run download-iam diff --git a/tasks/make_wordpieces.py b/tasks/make_wordpieces.py new file mode 100644 index 0000000..2ac0e2c --- /dev/null +++ b/tasks/make_wordpieces.py @@ -0,0 +1,114 @@ +"""Creates word pieces from a text file. + +Most code stolen from: + + https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py + +""" +import io +from pathlib import Path +from typing import List, Optional, Union + +import click +from loguru import logger +import sentencepiece as spm + +from text_recognizer.datasets.iam_preprocessor import load_metadata + + +def iamdb_pieces( + data_dir: Path, text_file: str, num_pieces: int, output_prefix: str +) -> None: + """Creates word pieces from the iamdb train text.""" + # Load training text. + with open(data_dir / text_file, "r") as f: + text = [line.strip() for line in f] + + sp = train_spm_model( + iter(text), + num_pieces + 1, # To account for <unk> + user_symbols=["/"], # added so token is in the output set + ) + + vocab = sorted(set(w for t in text for w in t.split("▁") if w)) + if "move" not in vocab: + raise RuntimeError("`MOVE` not in vocab") + + save_pieces(sp, num_pieces, data_dir, output_prefix, vocab) + + +def train_spm_model( + sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = "" +) -> spm.SentencePieceProcessor: + """Trains the sentence piece model.""" + model = io.BytesIO() + spm.SentencePieceTrainer.train( + sentence_iterator=sentences, + model_writer=model, + vocab_size=vocab_size, + bos_id=-1, + eos_id=-1, + character_coverage=1.0, + user_defined_symbols=user_symbols, + ) + sp = spm.SentencePieceProcessor(model_proto=model.getvalue()) + return sp + + +def save_pieces( + sp: spm.SentencePieceProcessor, + num_pieces: int, + data_dir: Path, + output_prefix: str, + vocab: set, +) -> None: + """Saves word pieces to disk.""" + logger.info(f"Generating word piece list of size {num_pieces}.") + pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)] + logger.info(f"Encoding vocabulary of size {len(vocab)}.") + encoded_vocab = [sp.encode_as_pieces(v) for v in vocab] + + # Save pieces to file. + with open(data_dir / f"{output_prefix}_tokens_{num_pieces}.txt", "w") as f: + f.write("\n".join(pieces)) + + # Save lexicon to a file. + with open(data_dir / f"{output_prefix}_lex_{num_pieces}.txt", "w") as f: + for v, p in zip(vocab, encoded_vocab): + f.write(f"{v} {' '.join(p)}\n") + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.") +@click.option( + "--text_file", type=str, default=None, help="Name of sentence piece training text." +) +@click.option( + "--output_prefix", + type=str, + default="word_pieces", + help="Prefix name to store tokens and lexicon.", +) +@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.") +def cli( + data_dir: Optional[str], + text_file: Optional[str], + output_prefix: Optional[str], + num_pieces: Optional[int], +) -> None: + """CLI for training the sentence piece model.""" + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + iamdb_pieces(data_dir, text_file, num_pieces, output_prefix) + + +if __name__ == "__main__": + cli() diff --git a/tasks/prepare_experiments.sh b/tasks/prepare_experiments.sh new file mode 100755 index 0000000..95a538f --- /dev/null +++ b/tasks/prepare_experiments.sh @@ -0,0 +1,3 @@ +#!/usr/bin/fish +experiments_filename=${1:-training/experiments/sample_experiment.yml} +poetry run prepare-experiments --experiments_filename $experiments_filename diff --git a/tasks/test_functionality.sh b/tasks/test_functionality.sh new file mode 100755 index 0000000..5ccf0cd --- /dev/null +++ b/tasks/test_functionality.sh @@ -0,0 +1,2 @@ +#!/usr/bin/fish +pytest -s -q text_recognizer diff --git a/tasks/train.sh b/tasks/train.sh new file mode 100755 index 0000000..60cbd23 --- /dev/null +++ b/tasks/train.sh @@ -0,0 +1,68 @@ +#!/bin/bash + + +# Add checkpoint and resume experiment +usage() { + cat << EOF + usage: ./tasks/train_crnn_line_ctc_model.sh + -f | --experiment_config Name of the experiment config. + -c | --checkpoint (Optional) The experiment name to continue from. + -p | --pretrained_weights (Optional) Path to pretrained weights. + -n | --notrain (Optional) Evaluates a trained model. + -t | --test (Optional) If set, evaluates the model on test set. + -v | --verbose (Optional) Sets the verbosity. + -h | --help Shows this message. +EOF +exit 1 +} + +experiment_config="" +checkpoint="" +pretrained_weights="" +notrain="" +test="" +verbose="" +train_command="" + +while getopts 'f:c:p:nthv' flag; do + case "${flag}" in + f) experiment_config="${OPTARG}" ;; + c) checkpoint="${OPTARG}" ;; + p) pretrained_weights="${OPTARG}" ;; + n) notrain="--notrain" ;; + t) test="--test" ;; + v) verbose="${verbose}v" ;; + h) usage ;; + *) error "Unexpected option ${flag}" ;; + esac +done + + +if [ -z ${experiment_config} ]; +then + echo "experiment_config not specified!" + usage + exit 1 +fi + +experiments_filename="training/experiments/${experiment_config}" +train_command=$(bash tasks/prepare_experiments.sh $experiments_filename) + +if [ ${checkpoint} ]; +then + train_command="${train_command} --checkpoint $checkpoint" +fi + +if [ ${pretrained_weights} ]; +then + train_command="${train_command} --pretrained_weights $pretrained_weights" +fi + +if [ ${verbose} ]; +then + train_command="${train_command} -$verbose" +fi + +train_command="${train_command} $test $notrain" +echo $train_command +eval $train_command |