diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:08:21 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:08:21 +0200 |
commit | 8b9662480abd6104869c65f52b4edc80c5bf635e (patch) | |
tree | 1ed393736ecbbb1d31857a7d0ce5e632048f09f4 /text_recognizer/models/transformer.py | |
parent | 283a6fb2c33213dc05d34f1163422f2855506337 (diff) |
Update transformer models
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index dbdc7f2..0acc30a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,10 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Tuple, Set +from typing import Tuple, Type, Set import attr import torch from torch import Tensor +from text_recognizer.data.base_mapping import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -13,6 +14,8 @@ from text_recognizer.models.base import BaseLitModel class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" + mapping: Type[AbstractMapping] = attr.ib() + max_output_len: int = attr.ib(default=451) start_token: str = attr.ib(default="<s>") end_token: str = attr.ib(default="<e>") |