summary | refs | log | tree | commit | diff
path: root/text_recognizer
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-06-06 23:02:50 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-06-06 23:02:50 +0200
commit: 2ba8f61184f56955db27cbab878ab05e01dfea07 (patch)
tree: b31ed74cb3e7bf74fdb059014f708efdc2ce3c35 /text_recognizer
parent: f27fc4d798e4a332b616736b6145c95b87cdabbd (diff)
Rename lit models
Diffstat (limited to 'text_recognizer')
-rw-r--r-- text_recognizer/models/base.py 2
-rw-r--r-- text_recognizer/models/transformer.py 21
2 files changed, 16 insertions, 7 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 63fe5a7..77c5509 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -13,7 +13,7 @@ from torchmetrics import Accuracy
from text_recognizer.data.mappings.base import AbstractMapping
-class BaseLitModel(LightningModule):
+class LitBase(LightningModule):
"""Abstract PyTorch Lightning class."""
def __init__(
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 9537dd9..4bbc671 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,24 +1,33 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Set, Tuple
+"""Lightning model for base Transformers."""
+from typing import Optional, Tuple, Type
+from omegaconf import DictConfig
import torch
-from torch import Tensor
+from torch import nn, Tensor
-from text_recognizer.models.base import BaseLitModel
+from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.models.base import LitBase
from text_recognizer.models.metrics import CharacterErrorRate
-class TransformerLitModel(BaseLitModel):
+class LitTransformer(LitBase):
"""A PyTorch Lightning model for transformer networks."""
def __init__(
self,
+ network: Type[nn.Module],
+ loss_fn: Type[nn.Module],
+ optimizer_configs: DictConfig,
+ lr_scheduler_configs: Optional[DictConfig],
+ mapping: Type[AbstractMapping],
max_output_len: int = 451,
start_token: str = "<s>",
end_token: str = "<e>",
pad_token: str = "<p>",
) -> None:
- super().__init__()
+ super().__init__(
+ network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping
+ )
self.max_output_len = max_output_len
self.start_token = start_token
self.end_token = end_token