diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/transformer.py | 5 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 10 |
2 files changed, 8 insertions, 7 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 8dd4db2..bc7e313 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,12 +1,9 @@ """PyTorch Lightning model for base Transformers.""" from typing import Dict, List, Optional, Union, Tuple, Type -from omegaconf import DictConfig, OmegaConf -import pytorch_lightning as pl -import torch +from omegaconf import DictConfig from torch import nn from torch import Tensor -import torch.nn.functional as F import wandb from text_recognizer.data.emnist import emnist_mapping diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 9857420..18e8691 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -1,10 +1,9 @@ """PyTorch Lightning model for base Transformers.""" from typing import Any, Dict, Union, Tuple, Type -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from torch import nn from torch import Tensor -import torch.nn.functional as F import wandb from text_recognizer.models.base import LitBaseModel @@ -35,7 +34,12 @@ class LitVQVAEModel(LitBaseModel): """Logs prediction on image with wandb.""" try: self.logger.experiment.log( - {title: [wandb.Image(data[0]), wandb.Image(reconstructions[0]),]} + { + title: [ + wandb.Image(data[0]), + wandb.Image(reconstructions[0]), + ] + } ) except AttributeError: pass |