summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/models/base.py2
-rw-r--r--text_recognizer/models/transformer.py21
2 files changed, 16 insertions, 7 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 63fe5a7..77c5509 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -13,7 +13,7 @@ from torchmetrics import Accuracy
from text_recognizer.data.mappings.base import AbstractMapping
-class BaseLitModel(LightningModule):
+class LitBase(LightningModule):
"""Abstract PyTorch Lightning class."""
def __init__(
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 9537dd9..4bbc671 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,24 +1,33 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Set, Tuple
+"""Lightning model for base Transformers."""
+from typing import Optional, Tuple, Type
+from omegaconf import DictConfig
import torch
-from torch import Tensor
+from torch import nn, Tensor
-from text_recognizer.models.base import BaseLitModel
+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.models.base import LitBase
from text_recognizer.models.metrics import CharacterErrorRate
-class TransformerLitModel(BaseLitModel):
+class LitTransformer(LitBase):
"""A PyTorch Lightning model for transformer networks."""
def __init__(
self,
+ network: Type[nn.Module],
+ loss_fn: Type[nn.Module],
+ optimizer_configs: DictConfig,
+ lr_scheduler_configs: Optional[DictConfig],
+ mapping: Type[AbstractMapping],
max_output_len: int = 451,
start_token: str = "<s>",
end_token: str = "<e>",
pad_token: str = "<p>",
) -> None:
- super().__init__()
+ super().__init__(
+ network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping
+ )
self.max_output_len = max_output_len
self.start_token = start_token
self.end_token = end_token