From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- text_recognizer/models/base.py | 2 +- text_recognizer/models/metrics.py | 4 ++-- text_recognizer/models/transformer.py | 16 ++++++++-------- text_recognizer/models/vqvae.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index caf63c1..8ce5c37 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -12,7 +12,7 @@ from torch import Tensor import torchmetrics -@attr.s +@attr.s(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 0eb42dc..f83c9e4 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -8,11 +8,11 @@ from torch import Tensor from torchmetrics import Metric -@attr.s +@attr.s(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set = attr.ib(converter=set) + ignore_indices: Set[Tensor] = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 0e01bb5..91e088d 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,5 +1,5 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Sequence, Tuple, Type +from typing import Tuple, Type, Set import attr import torch @@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping: Type[AbstractMapping] = attr.ib() - start_token: str = attr.ib() - end_token: str = attr.ib() - pad_token: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib(default=None) + start_token: str = attr.ib(default="") + end_token: str = attr.ib(default="") + pad_token: str = attr.ib(default="

") start_index: Tensor = attr.ib(init=False) end_index: Tensor = attr.ib(init=False) pad_index: Tensor = attr.ib(init=False) - ignore_indices: Sequence[str] = attr.ib(init=False) + ignore_indices: Set[Tensor] = attr.ib(init=False) val_cer: CharacterErrorRate = attr.ib(init=False) test_cer: CharacterErrorRate = attr.ib(init=False) @@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel): self.start_index = self.mapping.get_index(self.start_token) self.end_index = self.mapping.get_index(self.end_token) self.pad_index = self.mapping.get_index(self.pad_token) - self.ignore_indices = [self.start_index, self.end_index, self.pad_index] + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) self.val_cer = CharacterErrorRate(self.ignore_indices) self.test_cer = CharacterErrorRate(self.ignore_indices) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index e215e14..22da018 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -10,7 +10,7 @@ import wandb from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" -- cgit v1.2.3-70-g09d2