From 2ba8f61184f56955db27cbab878ab05e01dfea07 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 6 Jun 2022 23:02:50 +0200 Subject: Rename lit models --- text_recognizer/models/base.py | 2 +- text_recognizer/models/transformer.py | 21 +++++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 63fe5a7..77c5509 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -13,7 +13,7 @@ from torchmetrics import Accuracy from text_recognizer.data.mappings.base import AbstractMapping -class BaseLitModel(LightningModule): +class LitBase(LightningModule): """Abstract PyTorch Lightning class.""" def __init__( diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 9537dd9..4bbc671 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,24 +1,33 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Set, Tuple +"""Lightning model for base Transformers.""" +from typing import Optional, Tuple, Type +from omegaconf import DictConfig import torch -from torch import Tensor +from torch import nn, Tensor -from text_recognizer.models.base import BaseLitModel +from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.models.base import LitBase from text_recognizer.models.metrics import CharacterErrorRate -class TransformerLitModel(BaseLitModel): +class LitTransformer(LitBase): """A PyTorch Lightning model for transformer networks.""" def __init__( self, + network: Type[nn.Module], + loss_fn: Type[nn.Module], + optimizer_configs: DictConfig, + lr_scheduler_configs: Optional[DictConfig], + mapping: Type[AbstractMapping], max_output_len: int = 451, start_token: str = "", end_token: str = "", pad_token: str = "

", ) -> None: - super().__init__() + super().__init__( + network, loss_fn, optimizer_configs, lr_scheduler_configs, mapping + ) self.max_output_len = max_output_len self.start_token = start_token self.end_token = end_token -- cgit v1.2.3-70-g09d2