summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py2
-rw-r--r--text_recognizer/models/metrics.py4
-rw-r--r--text_recognizer/models/transformer.py16
-rw-r--r--text_recognizer/models/vqvae.py2
4 files changed, 12 insertions, 12 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index caf63c1..8ce5c37 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -12,7 +12,7 @@ from torch import Tensor
import torchmetrics
-@attr.s
+@attr.s(eq=False)
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 0eb42dc..f83c9e4 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -8,11 +8,11 @@ from torch import Tensor
from torchmetrics import Metric
-@attr.s
+@attr.s(eq=False)
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_indices: Set = attr.ib(converter=set)
+ ignore_indices: Set[Tensor] = attr.ib(converter=set)
error: Tensor = attr.ib(init=False)
total: Tensor = attr.ib(init=False)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 0e01bb5..91e088d 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Sequence, Tuple, Type
+from typing import Tuple, Type, Set
import attr
import torch
@@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib()
- start_token: str = attr.ib()
- end_token: str = attr.ib()
- pad_token: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib(default=None)
+ start_token: str = attr.ib(default="<s>")
+ end_token: str = attr.ib(default="<e>")
+ pad_token: str = attr.ib(default="<p>")
start_index: Tensor = attr.ib(init=False)
end_index: Tensor = attr.ib(init=False)
pad_index: Tensor = attr.ib(init=False)
- ignore_indices: Sequence[str] = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
@@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel):
self.start_index = self.mapping.get_index(self.start_token)
self.end_index = self.mapping.get_index(self.end_token)
self.pad_index = self.mapping.get_index(self.pad_token)
- self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index e215e14..22da018 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -10,7 +10,7 @@ import wandb
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""