summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md1
-rw-r--r--text_recognizer/data/iam_paragraphs.py1
-rw-r--r--text_recognizer/data/mappings.py4
-rw-r--r--training/callbacks/wandb_callbacks.py2
-rw-r--r--training/run.py10
-rw-r--r--training/utils.py2
6 files changed, 11 insertions, 9 deletions
diff --git a/README.md b/README.md
index 43cf05f..b456055 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw
## Todo
- [x] Efficient-net b0 + transformer decoder
- [ ] Load everything with hydra, get it to work
+- [ ] Train network
- [ ] Tests
- [ ] Evaluation
- [ ] Wandb artifact fetcher
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 7ba1077..0f3a2ce 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -39,6 +39,7 @@ class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
num_classes: int = attr.ib()
+ word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
dims: Tuple[int, int, int] = attr.ib(
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index a934fd9..b69e888 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Union, Set, Sequence
import attr
-from loguru import logger
+import loguru.logger as log
import torch
from torch import Tensor
@@ -87,7 +87,7 @@ class WordPieceMapping(EmnistMapping):
if self.data_dir is None
else Path(self.data_dir)
)
- logger.debug(f"Using data dir: {self.data_dir}")
+ log.debug(f"Using data dir: {self.data_dir}")
if not self.data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 451b0d5..6379cc0 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -94,7 +94,7 @@ class LogTextPredictions(Callback):
super().__init__()
def _log_predictions(
- stage: str, trainer: Trainer, pl_module: LightningModule
+ self, stage: str, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs the predicted text contained in the images."""
if not self.ready:
diff --git a/training/run.py b/training/run.py
index f745d61..d88a8f6 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Type
import hydra
-from loguru import logger as log
+import loguru.logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
@@ -33,11 +33,11 @@ def run(config: DictConfig) -> Optional[float]:
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
- config.model,
+ **config.model,
network=network,
- criterion=config.criterion,
- optimizer=config.optimizer,
- lr_scheduler=config.lr_scheduler,
+ criterion_config=config.criterion,
+ optimizer_config=config.optimizer,
+ lr_scheduler_config=config.lr_scheduler,
_recursive_=False,
)
diff --git a/training/utils.py b/training/utils.py
index ef74f61..564b9bb 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -3,7 +3,7 @@ from typing import Any, List, Type
import warnings
import hydra
-from loguru import logger as log
+import loguru.logger as log
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import (
Callback,