diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/artifacts.py | 7 | ||||
-rw-r--r-- | training/callbacks/wandb.py | 5 | ||||
-rw-r--r-- | training/main.py | 1 | ||||
-rw-r--r-- | training/run.py | 4 | ||||
-rw-r--r-- | training/utils.py | 3 |
5 files changed, 12 insertions, 8 deletions
diff --git a/training/artifacts.py b/training/artifacts.py index 2cb8bda..f7af085 100644 --- a/training/artifacts.py +++ b/training/artifacts.py @@ -1,14 +1,15 @@ """Fetches model artifacts from wandb.""" -from datetime import datetime -from pathlib import Path import shutil import sys +from datetime import datetime +from pathlib import Path from typing import Optional import click from loguru import logger as log -from training import metadata + import wandb +from training import metadata from wandb.apis.public import Run diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py index 6adbebe..0841000 100644 --- a/training/callbacks/wandb.py +++ b/training/callbacks/wandb.py @@ -2,11 +2,12 @@ from pathlib import Path from typing import Tuple -import wandb -from torch import Tensor from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only +from torch import Tensor + +import wandb def get_wandb_logger(trainer: Trainer) -> WandbLogger: diff --git a/training/main.py b/training/main.py index c36e397..17b961c 100644 --- a/training/main.py +++ b/training/main.py @@ -1,6 +1,7 @@ """Loads config with hydra and runs experiment.""" import hydra from omegaconf import DictConfig + from training.metadata import TRAINING_DIR diff --git a/training/run.py b/training/run.py index b8f2178..2089a86 100644 --- a/training/run.py +++ b/training/run.py @@ -2,19 +2,19 @@ from typing import Callable, List, Optional, Type import hydra +import utils from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, LightningDataModule, LightningModule, - seed_everything, Trainer, + seed_everything, ) from pytorch_lightning.loggers import Logger from torch import nn from torchinfo import summary -import utils def run(config: DictConfig) -> Optional[float]: diff --git a/training/utils.py b/training/utils.py index d1801a7..f0a0d4d 100644 --- a/training/utils.py +++ b/training/utils.py @@ -1,6 +1,6 @@ """Util functions for training with hydra and pytorch lightning.""" -from typing import List, Type import warnings +from typing import List, Type import hydra from loguru import logger as log @@ -14,6 +14,7 @@ from pytorch_lightning.loggers import Logger from pytorch_lightning.loggers.wandb import WandbLogger from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm + import wandb |