summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/transforms.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-08-03 18:18:48 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-08-03 18:18:48 +0200
commit: bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
tree: e55cb3744904f7c2a0348b100c7e92a65e538a16 /text_recognizer/data/transforms.py
parent: 75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r-- text_recognizer/data/transforms.py | 8
1 file changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 3b1b929..047496f 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,11 +1,11 @@
"""Transforms for PyTorch datasets."""
from pathlib import Path
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Set
import torch
from torch import Tensor
-from text_recognizer.data.mappings import WordPieceMapping
+from text_recognizer.data.word_piece_mapping import WordPieceMapping
class WordPiece:
@@ -19,8 +19,8 @@ class WordPiece:
data_dir: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Optional[Set[str]] = {"\n",},
max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(