From 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 22 Apr 2021 08:15:58 +0200 Subject: Fixed training script, able to train vqvae --- text_recognizer/data/transforms.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'text_recognizer/data/transforms.py') diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index f53df64..8d1bedd 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence from torch import Tensor -from text_recognizer.datasets.mappings import WordPieceMapping +from text_recognizer.data.mappings import WordPieceMapping class WordPiece: @@ -12,14 +12,15 @@ class WordPiece: def __init__( self, - num_features: int, - tokens: str, - lexicon: str, + num_features: int = 1000, + tokens: str = "iamdb_1kwp_tokens_1000.txt" , + lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("", "", "

"), - extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Sequence[str]] = ("\n",), + max_len: int = 192, ) -> None: self.mapping = WordPieceMapping( num_features, @@ -31,6 +32,7 @@ class WordPiece: special_tokens, extra_symbols, ) + self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x) + return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len] -- cgit v1.2.3-70-g09d2