diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-02-24 22:00:29 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-02-24 22:00:29 +0100 |
commit | 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 (patch) | |
tree | 91dab598a94911e6147b996237e786dd47f11f2f /src/tasks | |
parent | 4a54d7e690897dd6e6c719fb908fd371a44c2952 (diff) |
updates
Diffstat (limited to 'src/tasks')
-rw-r--r-- | src/tasks/build_transitions.py | 6 | ||||
-rw-r--r-- | src/tasks/make_wordpieces.py | 2 |
2 files changed, 4 insertions, 4 deletions
diff --git a/src/tasks/build_transitions.py b/src/tasks/build_transitions.py index b12c9bc..91f8c1a 100644 --- a/src/tasks/build_transitions.py +++ b/src/tasks/build_transitions.py @@ -9,7 +9,7 @@ Most code stolen from here: import collections import itertools from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import click import gtn @@ -18,7 +18,7 @@ from loguru import logger START_IDX = -1 END_IDX = -2 -WORDSEP = "_" +WORDSEP = "▁" def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph: @@ -27,7 +27,7 @@ def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph: ngram = len(ngrams) state_to_node = {} - def get_node(state: Optional[List]) -> gtn.node: + def get_node(state: Optional[List]) -> Any: node = state_to_node.get(state, None) if node is not None: diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py index f605920..2ac0e2c 100644 --- a/src/tasks/make_wordpieces.py +++ b/src/tasks/make_wordpieces.py @@ -30,7 +30,7 @@ def iamdb_pieces( user_symbols=["/"], # added so token is in the output set ) - vocab = sorted(set(w for t in text for w in t.split("_") if w)) + vocab = sorted(set(w for t in text for w in t.split("▁") if w)) if "move" not in vocab: raise RuntimeError("`MOVE` not in vocab") |