diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
commit | 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch) | |
tree | 5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/data/transforms.py | |
parent | ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff) |
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r-- | text_recognizer/data/transforms.py | 14 |
1 files changed, 8 insertions, 6 deletions
```diff
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index f53df64..8d1bedd 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence
 
 from torch import Tensor
 
-from text_recognizer.datasets.mappings import WordPieceMapping
+from text_recognizer.data.mappings import WordPieceMapping
 
 
 class WordPiece:
@@ -12,14 +12,15 @@ class WordPiece:
 
     def __init__(
         self,
-        num_features: int,
-        tokens: str,
-        lexicon: str,
+        num_features: int = 1000,
+        tokens: str = "iamdb_1kwp_tokens_1000.txt" ,
+        lexicon: str = "iamdb_1kwp_lex_1000.txt",
         data_dir: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = None,
+        extra_symbols: Optional[Sequence[str]] = ("\n",),
+        max_len: int = 192,
     ) -> None:
         self.mapping = WordPieceMapping(
             num_features,
@@ -31,6 +32,7 @@ class WordPiece:
             special_tokens,
             extra_symbols,
         )
+        self.max_len = max_len
 
     def __call__(self, x: Tensor) -> Tensor:
-        return self.mapping.emnist_to_wordpiece_indices(x)
+        return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len]
```