diff options
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/mappings.py | 4 | ||||
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 2 | ||||
-rw-r--r-- | training/run.py | 10 | ||||
-rw-r--r-- | training/utils.py | 2 |
6 files changed, 11 insertions, 9 deletions
@@ -28,6 +28,7 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw ## Todo - [x] Efficient-net b0 + transformer decoder - [ ] Load everything with hydra, get it to work +- [ ] Train network - [ ] Tests - [ ] Evaluation - [ ] Wandb artifact fetcher diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 7ba1077..0f3a2ce 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -39,6 +39,7 @@ class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" num_classes: int = attr.ib() + word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) dims: Tuple[int, int, int] = attr.ib( diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index a934fd9..b69e888 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union, Set, Sequence import attr -from loguru import logger +import loguru.logger as log import torch from torch import Tensor @@ -87,7 +87,7 @@ class WordPieceMapping(EmnistMapping): if self.data_dir is None else Path(self.data_dir) ) - logger.debug(f"Using data dir: {self.data_dir}") + log.debug(f"Using data dir: {self.data_dir}") if not self.data_dir.exists(): raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}") diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 451b0d5..6379cc0 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -94,7 +94,7 @@ class LogTextPredictions(Callback): super().__init__() def _log_predictions( - stage: str, trainer: Trainer, pl_module: LightningModule + self, stage: str, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs the predicted text contained in the images.""" if not self.ready: diff --git a/training/run.py b/training/run.py index f745d61..d88a8f6 100644 --- a/training/run.py +++ b/training/run.py @@ -2,7 +2,7 @@ from typing import List, Optional, Type import hydra -from loguru import logger as log +import loguru.logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, @@ -33,11 +33,11 @@ def run(config: DictConfig) -> Optional[float]: log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( - config.model, + **config.model, network=network, - criterion=config.criterion, - optimizer=config.optimizer, - lr_scheduler=config.lr_scheduler, + criterion_config=config.criterion, + optimizer_config=config.optimizer, + lr_scheduler_config=config.lr_scheduler, _recursive_=False, ) diff --git a/training/utils.py b/training/utils.py index ef74f61..564b9bb 100644 --- a/training/utils.py +++ b/training/utils.py @@ -3,7 +3,7 @@ from typing import Any, List, Type import warnings import hydra -from loguru import logger as log +import loguru.logger as log from omegaconf import DictConfig, OmegaConf from pytorch_lightning import ( Callback, |