summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py1
-rw-r--r--text_recognizer/models/transformer.py5
2 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 0928e6c..c6d5d73 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -60,6 +60,7 @@ class LitBaseModel(pl.LightningModule):
scheduler["scheduler"] = getattr(
torch.optim.lr_scheduler, self._lr_scheduler.type
)(optimizer, **args)
+
return scheduler
def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index b23685b..7dc1352 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple, Type
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
@@ -19,7 +19,7 @@ class LitTransformerModel(LitBaseModel):
def __init__(
self,
- network: Type[nn, Module],
+ network: Type[nn.Module],
optimizer: Union[DictConfig, Dict],
lr_scheduler: Union[DictConfig, Dict],
criterion: Union[DictConfig, Dict],
@@ -27,7 +27,6 @@ class LitTransformerModel(LitBaseModel):
mapping: Optional[List[str]] = None,
) -> None:
super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)