summary | refs | log | tree | commit | diff
path: root/src/tasks
diff options
context:
space:
mode:
author: aktersnurra <gustaf.rydholm@gmail.com> 2021-02-24 22:00:29 +0100
committer: aktersnurra <gustaf.rydholm@gmail.com> 2021-02-24 22:00:29 +0100
commit: 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 (patch)
tree: 91dab598a94911e6147b996237e786dd47f11f2f /src/tasks
parent: 4a54d7e690897dd6e6c719fb908fd371a44c2952 (diff)
updates
Diffstat (limited to 'src/tasks')
-rw-r--r-- src/tasks/build_transitions.py 6
-rw-r--r-- src/tasks/make_wordpieces.py 2
2 files changed, 4 insertions, 4 deletions
diff --git a/src/tasks/build_transitions.py b/src/tasks/build_transitions.py
index b12c9bc..91f8c1a 100644
--- a/src/tasks/build_transitions.py
+++ b/src/tasks/build_transitions.py
@@ -9,7 +9,7 @@ Most code stolen from here:
import collections
import itertools
from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
import click
import gtn
@@ -18,7 +18,7 @@ from loguru import logger
START_IDX = -1
END_IDX = -2
-WORDSEP = "_"
+WORDSEP = "▁"
def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
@@ -27,7 +27,7 @@ def build_graph(ngrams: List, disable_backoff: bool = False) -> gtn.Graph:
ngram = len(ngrams)
state_to_node = {}
- def get_node(state: Optional[List]) -> gtn.node:
+ def get_node(state: Optional[List]) -> Any:
node = state_to_node.get(state, None)
if node is not None:
diff --git a/src/tasks/make_wordpieces.py b/src/tasks/make_wordpieces.py
index f605920..2ac0e2c 100644
--- a/src/tasks/make_wordpieces.py
+++ b/src/tasks/make_wordpieces.py
@@ -30,7 +30,7 @@ def iamdb_pieces(
user_symbols=["/"], # added so token is in the output set
)
- vocab = sorted(set(w for t in text for w in t.split("_") if w))
+ vocab = sorted(set(w for t in text for w in t.split("▁") if w))
if "move" not in vocab:
raise RuntimeError("`MOVE` not in vocab")