summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:08:21 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:08:21 +0200
commit8b9662480abd6104869c65f52b4edc80c5bf635e (patch)
tree1ed393736ecbbb1d31857a7d0ce5e632048f09f4 /text_recognizer/models/transformer.py
parent283a6fb2c33213dc05d34f1163422f2855506337 (diff)
Update transformer models
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index dbdc7f2..0acc30a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,10 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple, Set
+from typing import Tuple, Type, Set
import attr
import torch
from torch import Tensor
+from text_recognizer.data.base_mapping import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -13,6 +14,8 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
+ mapping: Type[AbstractMapping] = attr.ib()
+
max_output_len: int = attr.ib(default=451)
start_token: str = attr.ib(default="<s>")
end_token: str = attr.ib(default="<e>")