summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/image_transformer.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index daededa..a6aaca4 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -44,7 +44,9 @@ class ImageTransformer(nn.Module):
dropout_rate: float = 0.1,
transformer_activation: str = "glu",
) -> None:
- self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+ self.vocab_size = (
+ NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+ )
self.hidden_dim = hidden_dim
self.max_output_length = output_shape[0]