diff options
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r-- | text_recognizer/models/transformer.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 285b715..3625ab2 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,6 +1,7 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Union, Tuple +from omegaconf import OmegaConf import pytorch_lightning as pl import torch from torch import nn @@ -18,15 +19,15 @@ class LitTransformerModel(LitBaseModel): def __init__( self, - network_args: Dict, - optimizer_args: Dict, - lr_scheduler_args: Dict, - criterion_args: Dict, + network: Type[nn,Module], + optimizer: Union[OmegaConf, Dict], + lr_scheduler: Union[OmegaConf, Dict], + criterion: Union[OmegaConf, Dict], monitor: str = "val_loss", mapping: Optional[List[str]] = None, ) -> None: super().__init__( - network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor + network, optimizer, lr_scheduler, criterion, monitor ) self.mapping, ignore_tokens = self.configure_mapping(mapping) @@ -40,6 +41,7 @@ class LitTransformerModel(LitBaseModel): @staticmethod def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: """Configure mapping.""" + # TODO: Fix me!!! mapping, inverse_mapping, _ = emnist_mapping() start_index = inverse_mapping["<s>"] end_index = inverse_mapping["<e>"] |