diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 2 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 16 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 2 |
4 files changed, 12 insertions, 12 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index caf63c1..8ce5c37 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -12,7 +12,7 @@ from torch import Tensor import torchmetrics -@attr.s +@attr.s(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 0eb42dc..f83c9e4 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -8,11 +8,11 @@ from torch import Tensor from torchmetrics import Metric -@attr.s +@attr.s(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set = attr.ib(converter=set) + ignore_indices: Set[Tensor] = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 0e01bb5..91e088d 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,5 +1,5 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Sequence, Tuple, Type +from typing import Tuple, Type, Set import attr import torch @@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping: Type[AbstractMapping] = attr.ib() - start_token: str = attr.ib() - end_token: str = attr.ib() - pad_token: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib(default=None) + start_token: str = attr.ib(default="<s>") + end_token: str = attr.ib(default="<e>") + pad_token: str = attr.ib(default="<p>") start_index: Tensor = attr.ib(init=False) end_index: Tensor = attr.ib(init=False) pad_index: Tensor = attr.ib(init=False) - ignore_indices: Sequence[str] = attr.ib(init=False) + ignore_indices: Set[Tensor] = attr.ib(init=False) val_cer: CharacterErrorRate = attr.ib(init=False) test_cer: CharacterErrorRate = attr.ib(init=False) @@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel): self.start_index = self.mapping.get_index(self.start_token) self.end_index = self.mapping.get_index(self.end_token) self.pad_index = self.mapping.get_index(self.pad_token) - self.ignore_indices = [self.start_index, self.end_index, self.pad_index] + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) self.val_cer = CharacterErrorRate(self.ignore_indices) self.test_cer = CharacterErrorRate(self.ignore_indices) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index e215e14..22da018 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -10,7 +10,7 @@ import wandb from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" |