summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-06 17:42:53 +0200
commiteb5b206f7e1b08435378d2a02395307be55ee6f1 (patch)
tree0cd30234afab698eb632b20a7da97e3bc7e98882 /text_recognizer/models/vqvae.py
parent4d1f2cef39688871d2caafce42a09316381a27ae (diff)
Refactoring data with attrs and refactor conf for hydra
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r--text_recognizer/models/vqvae.py34
1 files changed, 4 insertions, 30 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 7dc950f..0172163 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -1,49 +1,23 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Any, Dict, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn
from torch import Tensor
import wandb
-from text_recognizer.models.base import LitBaseModel
+from text_recognizer.models.base import BaseLitModel
-class LitVQVAEModel(LitBaseModel):
+@attr.s(auto_attribs=True)
+class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val/loss",
- *args: Any,
- **kwargs: Dict,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
-
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- def _log_prediction(
- self, data: Tensor, reconstructions: Tensor, title: str
- ) -> None:
- """Logs prediction on image with wandb."""
- try:
- self.logger.experiment.log(
- {
- title: [
- wandb.Image(data[0]),
- wandb.Image(reconstructions[0]),
- ]
- }
- )
- except AttributeError:
- pass
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, _ = batch