summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-15 21:47:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-15 21:47:54 +0200
commit9ce21f569ecac03f15f2ad669fde3dd4a512f8cc (patch)
treee6f5bbf4cfe758788fd6ad3679b714d4ecfad568 /training
parenta93dcc5b9c8160a441c5b5f99f2f59264778ef91 (diff)
Format
Diffstat (limited to 'training')
-rw-r--r--training/artifacts.py7
-rw-r--r--training/callbacks/wandb.py5
-rw-r--r--training/main.py1
-rw-r--r--training/run.py4
-rw-r--r--training/utils.py3
5 files changed, 12 insertions, 8 deletions
diff --git a/training/artifacts.py b/training/artifacts.py
index 2cb8bda..f7af085 100644
--- a/training/artifacts.py
+++ b/training/artifacts.py
@@ -1,14 +1,15 @@
"""Fetches model artifacts from wandb."""
-from datetime import datetime
-from pathlib import Path
import shutil
import sys
+from datetime import datetime
+from pathlib import Path
from typing import Optional
import click
from loguru import logger as log
-from training import metadata
+
import wandb
+from training import metadata
from wandb.apis.public import Run
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py
index 6adbebe..0841000 100644
--- a/training/callbacks/wandb.py
+++ b/training/callbacks/wandb.py
@@ -2,11 +2,12 @@
from pathlib import Path
from typing import Tuple
-import wandb
-from torch import Tensor
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
+from torch import Tensor
+
+import wandb
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
diff --git a/training/main.py b/training/main.py
index c36e397..17b961c 100644
--- a/training/main.py
+++ b/training/main.py
@@ -1,6 +1,7 @@
"""Loads config with hydra and runs experiment."""
import hydra
from omegaconf import DictConfig
+
from training.metadata import TRAINING_DIR
diff --git a/training/run.py b/training/run.py
index b8f2178..2089a86 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,19 +2,19 @@
from typing import Callable, List, Optional, Type
import hydra
+import utils
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
- seed_everything,
Trainer,
+ seed_everything,
)
from pytorch_lightning.loggers import Logger
from torch import nn
from torchinfo import summary
-import utils
def run(config: DictConfig) -> Optional[float]:
diff --git a/training/utils.py b/training/utils.py
index d1801a7..f0a0d4d 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -1,6 +1,6 @@
"""Util functions for training with hydra and pytorch lightning."""
-from typing import List, Type
import warnings
+from typing import List, Type
import hydra
from loguru import logger as log
@@ -14,6 +14,7 @@ from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
+
import wandb