From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 20 Mar 2021 18:09:06 +0100 Subject: Inital commit for refactoring to lightning --- src/tasks/build_transitions.py | 263 ------------------------------ src/tasks/create_emnist_lines_datasets.sh | 4 - src/tasks/create_iam_paragraphs.sh | 2 - src/tasks/download_emnist.sh | 3 - src/tasks/download_iam.sh | 2 - src/tasks/make_wordpieces.py | 114 ------------- src/tasks/prepare_experiments.sh | 3 - src/tasks/test_functionality.sh | 2 - src/tasks/train.sh | 68 -------- 9 files changed, 461 deletions(-) delete mode 100644 src/tasks/build_transitions.py delete mode 100755 src/tasks/create_emnist_lines_datasets.sh delete mode 100755 src/tasks/create_iam_paragraphs.sh delete mode 100755 src/tasks/download_emnist.sh delete mode 100755 src/tasks/download_iam.sh delete mode 100644 src/tasks/make_wordpieces.py delete mode 100755 src/tasks/prepare_experiments.sh delete mode 100755 src/tasks/test_functionality.sh delete mode 100755 src/tasks/train.sh (limited to 'src/tasks') diff --git a/src/tasks/build_transitions.py b/src/tasks/build_transitions.py deleted file mode 100644 index 91f8c1a..0000000 --- a/src/tasks/build_transitions.py +++ /dev/null @@ -1,263 +0,0 @@ -"""Builds transition graph. - -Most code stolen from here: - - https://github.com/facebookresearch/gtn_applications/blob/master/scripts/build_transitions.py - -""" - -import collections -import itertools -from pathlib import Path -from typing import Any, Dict, List, Optional - -import click -import gtn -from loguru import logger - - -START_IDX = -1 -END_IDX = -2 -WORDSEP = "▁" - - -def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph: - """Returns a gtn Graph based on the ngrams.""" - graph = gtn.Graph(False) - ngram = len(ngrams) - state_to_node = {} - - def get_node(state: Optional[List]) -> Any: - node = state_to_node.get(state, None) - - if node is not None: - return node - - start = state == tuple([START_IDX]) if ngram > 1 else True - end = state == tuple([END_IDX]) if ngram > 1 else True - node = graph.add_node(start, end) - state_to_node[state] = node - - if not disable_backoff and not end: - # Add back off when adding node. - for n in range(1, len(state) + 1): - backoff_node = state_to_node.get(state[n:], None) - - # Epsilon transition to the back-off state. - if backoff_node is not None: - graph.add_arc(node, backoff_node, gtn.epsilon) - break - return node - - for grams in ngrams: - for gram in grams: - istate, ostate = gram[:-1], gram[len(gram) - ngram + 1 :] - inode = get_node(istate) - - if END_IDX not in gram[1:] and gram[1:] not in state_to_node: - raise ValueError( - "Ill formed counts: if (x, y_1, ..., y_{n-1}) is above" - "the n-gram threshold, then (y_1, ..., y_{n-1}) must be" - "above the (n-1)-gram threshold" - ) - - if END_IDX in ostate: - # Merge all state having into one as final graph generated - # will be similar. - ostate = tuple([END_IDX]) - - onode = get_node(ostate) - # p(gram[-1] | gram[:-1]) - graph.add_arc( - inode, onode, gtn.epsilon if gram[-1] == END_IDX else gram[-1] - ) - return graph - - -def count_ngrams(lines: List, ngram: List, tokens_to_index: Dict) -> List: - """Counts the number of ngrams.""" - counts = [collections.Counter() for _ in range(ngram)] - for line in lines: - # Prepend implicit start token. - token_line = [START_IDX] - for t in line: - token_line.append(tokens_to_index[t]) - token_line.append(END_IDX) - for n, counter in enumerate(counts): - start_offset = n == 0 - end_offset = ngram == 1 - for e in range(n + start_offset, len(token_line) - end_offset): - counter[tuple(token_line[e - n : e + 1])] += 1 - - return counts - - -def prune_ngrams(ngrams: List, prune: List) -> List: - """Prunes ngrams.""" - pruned_ngrams = [] - for n, grams in enumerate(ngrams): - grams = grams.most_common() - pruned_grams = [gram for gram, c in grams if c > prune[n]] - pruned_ngrams.append(pruned_grams) - return pruned_ngrams - - -def add_blank_grams(pruned_ngrams: List, num_tokens: int, blank: str) -> List: - """Adds blank token to grams.""" - all_grams = [gram for grams in pruned_ngrams for gram in grams] - maxorder = len(pruned_ngrams) - blank_grams = {} - if blank == "forced": - pruned_ngrams = [pruned_ngrams[0] if i == 0 else [] for i in range(maxorder)] - pruned_ngrams[0].append(tuple([num_tokens])) - blank_grams[tuple([num_tokens])] = True - - for gram in all_grams: - # Iterate over all possibilities by using a vector of 0s, 1s to - # denote whether a blank is being used at each position. - if blank == "optional": - # Given a gram ab.. if order n, we have n + 1 positions - # available whether to use blank or not. - onehot_vectors = itertools.product([0, 1], repeat=len(gram) + 1) - elif blank == "forced": - # Must include a blank token in between. - onehot_vectors = [[1] * (len(gram) + 1)] - else: - raise ValueError( - "Invalid value specificed for blank. Must be in |optional|forced|none|" - ) - - for j in onehot_vectors: - new_array = [] - for idx, oz in enumerate(j[:-1]): - if oz == 1 and gram[idx] != START_IDX: - new_array.append(num_tokens) - new_array.append(gram[idx]) - if j[-1] == 1 and gram[-1] != END_IDX: - new_array.append(num_tokens) - for n in range(maxorder): - for e in range(n, len(new_array)): - cur_gram = tuple(new_array[e - n : e + 1]) - if num_tokens in cur_gram and cur_gram not in blank_grams: - pruned_ngrams[n].append(cur_gram) - blank_grams[cur_gram] = True - - return pruned_ngrams - - -def add_self_loops(pruned_ngrams: List) -> List: - """Adds self loops to the ngrams.""" - maxorder = len(pruned_ngrams) - - # Use dict for fast search. - all_grams = set([gram for grams in pruned_ngrams for gram in grams]) - for o in range(1, maxorder): - for gram in pruned_ngrams[o - 1]: - # Repeat one of the tokens. - for pos in range(len(gram)): - if gram[pos] == START_IDX or gram[pos] == END_IDX: - continue - new_gram = gram[:pos] + (gram[pos],) + gram[pos:] - - if new_gram not in all_grams: - pruned_ngrams[o].append(new_gram) - all_grams.add(new_gram) - return pruned_ngrams - - -def parse_lines(lines: List, lexicon: Path) -> List: - """Parses lines with a lexicon.""" - with open(lexicon, "r") as f: - lex = (line.strip().split() for line in f) - lex = {line[0]: line[1:] for line in lex} - print(len(lex)) - return [[t for w in line.split(WORDSEP) for t in lex[w]] for line in lines] - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to dataset root.") -@click.option( - "--tokens", type=str, help="Path to token list (in order used with training)." -) -@click.option("--lexicon", type=str, default=None, help="Path to lexicon") -@click.option( - "--prune", - nargs=2, - type=int, - help="Threshold values for prune unigrams, bigrams, etc.", -) -@click.option( - "--blank", - default=click.Choice(["none", "optional", "forced"]), - help="Specifies the usage of blank token" - "'none' - do not use blank token " - "'optional' - allow an optional blank inbetween tokens" - "'forced' - force a blank inbetween tokens (also referred to as garbage token)", -) -@click.option("--self_loops", is_flag=True, help="Add self loops for tokens") -@click.option("--disable_backoff", is_flag=True, help="Disable backoff transitions") -@click.option("--save_path", default=None, help="Path to save transition graph.") -def cli( - data_dir: str, - tokens: str, - lexicon: str, - prune: List[int], - blank: str, - self_loops: bool, - disable_backoff: bool, - save_path: str, -) -> None: - """CLI for creating the transitions.""" - logger.info(f"Building {len(prune)}-gram transition models.") - - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - # Build table of counts and the back-off if below threshold. - with open(data_dir / "train.txt", "r") as f: - lines = [line.strip() for line in f] - - with open(data_dir / tokens, "r") as f: - tokens = [line.strip() for line in f] - - if lexicon is not None: - lexicon = data_dir / lexicon - lines = parse_lines(lines, lexicon) - - tokens_to_idx = {t: e for e, t in enumerate(tokens)} - - ngram = len(prune) - - logger.info("Counting data...") - ngrams = count_ngrams(lines, ngram, tokens_to_idx) - - pruned_ngrams = prune_ngrams(ngrams, prune) - - for n in range(ngram): - logger.info(f"Kept {len(pruned_ngrams[n])} of {len(ngrams[n])} {n + 1}-grams") - - if blank == "none": - pruned_ngrams = add_blank_grams(pruned_ngrams, len(tokens_to_idx), blank) - - if self_loops: - pruned_ngrams = add_self_loops(pruned_ngrams) - - logger.info("Building graph from pruned ngrams...") - graph = build_graph(pruned_ngrams, disable_backoff) - logger.info(f"Graph has {graph.num_arcs()} arcs and {graph.num_nodes()} nodes.") - - save_path = str(data_dir / save_path) - - logger.info(f"Saving graph to {save_path}") - gtn.save(save_path, graph) - - -if __name__ == "__main__": - cli() diff --git a/src/tasks/create_emnist_lines_datasets.sh b/src/tasks/create_emnist_lines_datasets.sh deleted file mode 100755 index 6416277..0000000 --- a/src/tasks/create_emnist_lines_datasets.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/fish -command="python text_recognizer/datasets/emnist_lines_dataset.py --max_length 34 --min_overlap 0.0 --max_overlap 0.33 --num_train 100000 --num_test 10000" -echo $command -eval $command diff --git a/src/tasks/create_iam_paragraphs.sh b/src/tasks/create_iam_paragraphs.sh deleted file mode 100755 index fa2bfb0..0000000 --- a/src/tasks/create_iam_paragraphs.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/fish -poetry run create-iam-paragraphs diff --git a/src/tasks/download_emnist.sh b/src/tasks/download_emnist.sh deleted file mode 100755 index 18c8e29..0000000 --- a/src/tasks/download_emnist.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/fish -poetry run download-emnist -poetry run create-emnist-support-files diff --git a/src/tasks/download_iam.sh b/src/tasks/download_iam.sh deleted file mode 100755 index e3cf76b..0000000 --- a/src/tasks/download_iam.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/fish -poetry run download-iam diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py deleted file mode 100644 index 2ac0e2c..0000000 --- a/src/tasks/make_wordpieces.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Creates word pieces from a text file. - -Most code stolen from: - - https://github.com/facebookresearch/gtn_applications/blob/master/scripts/make_wordpieces.py - -""" -import io -from pathlib import Path -from typing import List, Optional, Union - -import click -from loguru import logger -import sentencepiece as spm - -from text_recognizer.datasets.iam_preprocessor import load_metadata - - -def iamdb_pieces( - data_dir: Path, text_file: str, num_pieces: int, output_prefix: str -) -> None: - """Creates word pieces from the iamdb train text.""" - # Load training text. - with open(data_dir / text_file, "r") as f: - text = [line.strip() for line in f] - - sp = train_spm_model( - iter(text), - num_pieces + 1, # To account for - user_symbols=["/"], # added so token is in the output set - ) - - vocab = sorted(set(w for t in text for w in t.split("▁") if w)) - if "move" not in vocab: - raise RuntimeError("`MOVE` not in vocab") - - save_pieces(sp, num_pieces, data_dir, output_prefix, vocab) - - -def train_spm_model( - sentences: iter, vocab_size: int, user_symbols: Union[str, List[str]] = "" -) -> spm.SentencePieceProcessor: - """Trains the sentence piece model.""" - model = io.BytesIO() - spm.SentencePieceTrainer.train( - sentence_iterator=sentences, - model_writer=model, - vocab_size=vocab_size, - bos_id=-1, - eos_id=-1, - character_coverage=1.0, - user_defined_symbols=user_symbols, - ) - sp = spm.SentencePieceProcessor(model_proto=model.getvalue()) - return sp - - -def save_pieces( - sp: spm.SentencePieceProcessor, - num_pieces: int, - data_dir: Path, - output_prefix: str, - vocab: set, -) -> None: - """Saves word pieces to disk.""" - logger.info(f"Generating word piece list of size {num_pieces}.") - pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)] - logger.info(f"Encoding vocabulary of size {len(vocab)}.") - encoded_vocab = [sp.encode_as_pieces(v) for v in vocab] - - # Save pieces to file. - with open(data_dir / f"{output_prefix}_tokens_{num_pieces}.txt", "w") as f: - f.write("\n".join(pieces)) - - # Save lexicon to a file. - with open(data_dir / f"{output_prefix}_lex_{num_pieces}.txt", "w") as f: - for v, p in zip(vocab, encoded_vocab): - f.write(f"{v} {' '.join(p)}\n") - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to processed iam dir.") -@click.option( - "--text_file", type=str, default=None, help="Name of sentence piece training text." -) -@click.option( - "--output_prefix", - type=str, - default="word_pieces", - help="Prefix name to store tokens and lexicon.", -) -@click.option("--num_pieces", type=int, default=1000, help="Number of word pieces.") -def cli( - data_dir: Optional[str], - text_file: Optional[str], - output_prefix: Optional[str], - num_pieces: Optional[int], -) -> None: - """CLI for training the sentence piece model.""" - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - iamdb_pieces(data_dir, text_file, num_pieces, output_prefix) - - -if __name__ == "__main__": - cli() diff --git a/src/tasks/prepare_experiments.sh b/src/tasks/prepare_experiments.sh deleted file mode 100755 index 95a538f..0000000 --- a/src/tasks/prepare_experiments.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/fish -experiments_filename=${1:-training/experiments/sample_experiment.yml} -poetry run prepare-experiments --experiments_filename $experiments_filename diff --git a/src/tasks/test_functionality.sh b/src/tasks/test_functionality.sh deleted file mode 100755 index 5ccf0cd..0000000 --- a/src/tasks/test_functionality.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/fish -pytest -s -q text_recognizer diff --git a/src/tasks/train.sh b/src/tasks/train.sh deleted file mode 100755 index 60cbd23..0000000 --- a/src/tasks/train.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash - - -# Add checkpoint and resume experiment -usage() { - cat << EOF - usage: ./tasks/train_crnn_line_ctc_model.sh - -f | --experiment_config Name of the experiment config. - -c | --checkpoint (Optional) The experiment name to continue from. - -p | --pretrained_weights (Optional) Path to pretrained weights. - -n | --notrain (Optional) Evaluates a trained model. - -t | --test (Optional) If set, evaluates the model on test set. - -v | --verbose (Optional) Sets the verbosity. - -h | --help Shows this message. -EOF -exit 1 -} - -experiment_config="" -checkpoint="" -pretrained_weights="" -notrain="" -test="" -verbose="" -train_command="" - -while getopts 'f:c:p:nthv' flag; do - case "${flag}" in - f) experiment_config="${OPTARG}" ;; - c) checkpoint="${OPTARG}" ;; - p) pretrained_weights="${OPTARG}" ;; - n) notrain="--notrain" ;; - t) test="--test" ;; - v) verbose="${verbose}v" ;; - h) usage ;; - *) error "Unexpected option ${flag}" ;; - esac -done - - -if [ -z ${experiment_config} ]; -then - echo "experiment_config not specified!" - usage - exit 1 -fi - -experiments_filename="training/experiments/${experiment_config}" -train_command=$(bash tasks/prepare_experiments.sh $experiments_filename) - -if [ ${checkpoint} ]; -then - train_command="${train_command} --checkpoint $checkpoint" -fi - -if [ ${pretrained_weights} ]; -then - train_command="${train_command} --pretrained_weights $pretrained_weights" -fi - -if [ ${verbose} ]; -then - train_command="${train_command} -$verbose" -fi - -train_command="${train_command} $test $notrain" -echo $train_command -eval $train_command -- cgit v1.2.3-70-g09d2