-rw-r--r--   README.md                                        |  3
-rw-r--r--   text_recognizer/data/iam_extended_paragraphs.py  |  6
-rw-r--r--   text_recognizer/data/iam_paragraphs.py           |  9
-rw-r--r--   text_recognizer/models/base.py                   |  3
-rw-r--r--   text_recognizer/models/metrics.py                |  4
-rw-r--r--   text_recognizer/models/transformer.py            |  1
6 files changed, 12 insertions(+), 14 deletions(-)
diff --git a/README.md b/README.md
index fa1b6a6..43cf05f 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw
(TODO: Not working atm, needed for GTN loss function)
## Todo
-- [ ] Efficient-net b0 + transformer decoder
+- [x] Efficient-net b0 + transformer decoder
- [ ] Load everything with hydra, get it to work
- [ ] Tests
- [ ] Evaluation
@@ -34,6 +34,7 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw
- [ ] attr refactor
- [ ] Refactor once more
- [ ] fix linting
+- [ ] fix loading of transforms for IAM paragraphs
## Run Sweeps
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 23e424d..0e97801 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -16,6 +16,7 @@ class IAMExtendedParagraphs(BaseDataModule):
    augment: bool = attr.ib(default=True)
    train_fraction: float = attr.ib(default=0.8)
    word_pieces: bool = attr.ib(default=False)
+    num_classes: int = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        self.iam_paragraphs = IAMParagraphs(
@@ -35,8 +36,7 @@ class IAMExtendedParagraphs(BaseDataModule):
        self.dims = self.iam_paragraphs.dims
        self.output_dims = self.iam_paragraphs.output_dims
-        self.mapping = self.iam_paragraphs.mapping
-        self.inverse_mapping = self.iam_paragraphs.inverse_mapping
+        self.num_classes = self.iam_paragraphs.num_classes

    def prepare_data(self) -> None:
        """Prepares the paragraphs data."""
@@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule):
"""Returns info about the dataset."""
basic = (
"IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
- f"Num classes: {len(self.mapping)}\n"
+ f"Num classes: {len(self.num_classes)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
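
Note: the iam_extended_paragraphs.py change relies on attrs' init=False pattern: the attribute is excluded from the generated __init__ and assigned in __attrs_post_init__ instead. A minimal sketch of that pattern, with an illustrative class and a hard-coded value standing in for what the real module copies from the wrapped IAMParagraphs instance:

import attr

@attr.s(auto_attribs=True)
class WrapperDataModule:
    # regular attributes are set through the generated __init__
    augment: bool = attr.ib(default=True)
    train_fraction: float = attr.ib(default=0.8)
    # init=False keeps the attribute out of __init__; it is filled in afterwards
    num_classes: int = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        # the real module copies this from the wrapped data module;
        # a constant is used here purely for illustration
        self.num_classes = 58

dm = WrapperDataModule()
print(dm.num_classes)  # -> 58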
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 82058e0..7ba1077 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -38,20 +38,17 @@ MAX_LABEL_LENGTH = 682
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
+ num_classes: int = attr.ib()
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
- word_pieces: bool = attr.ib(default=False)
dims: Tuple[int, int, int] = attr.ib(
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- self.mapping, self.inverse_mapping, _ = emnist_mapping(
- extra_symbols=[NEW_LINE_TOKEN]
- )
- if self.word_pieces:
- self.mapping = WordPieceMapping()
+ _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN])
def prepare_data(self) -> None:
"""Create data for training/testing."""
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index dfb4ca4..caf63c1 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -23,12 +23,9 @@ class BaseLitModel(LightningModule):
    criterion_config: DictConfig = attr.ib(converter=DictConfig)
    optimizer_config: DictConfig = attr.ib(converter=DictConfig)
    lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
-
    interval: str = attr.ib()
    monitor: str = attr.ib(default="val/loss")
-
    loss_fn: Type[nn.Module] = attr.ib(init=False)
-
    train_acc: torchmetrics.Accuracy = attr.ib(
        init=False, default=torchmetrics.Accuracy()
    )
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 9793157..0eb42dc 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -26,7 +26,9 @@ class CharacterErrorRate(Metric):
        bsz = preds.shape[0]
        for index in range(bsz):
            pred = [p for p in preds[index].tolist() if p not in self.ignore_indices]
-            target = [t for t in targets[index].tolist() if t not in self.ignore_indices]
+            target = [
+                t for t in targets[index].tolist() if t not in self.ignore_indices
+            ]
            distance = editdistance.distance(pred, target)
            error = distance / max(len(pred), len(target))
            self.error += error
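
Note: the reformatted update() computes a per-sample character error rate: drop the ignored indices (padding and special tokens), take the edit distance between prediction and target, and normalise by the longer of the two filtered sequences. A standalone sketch of that computation, assuming an illustrative ignore set:

from typing import FrozenSet, Sequence

import editdistance

def char_error_rate(
    pred: Sequence[int],
    target: Sequence[int],
    ignore_indices: FrozenSet[int] = frozenset({0, 1, 2, 3}),  # illustrative special-token indices
) -> float:
    """Edit distance between the filtered sequences, normalised by the longer one."""
    filtered_pred = [p for p in pred if p not in ignore_indices]
    filtered_target = [t for t in target if t not in ignore_indices]
    distance = editdistance.distance(filtered_pred, filtered_target)
    return distance / max(len(filtered_pred), len(filtered_target))

print(char_error_rate([5, 6, 7, 3, 3], [5, 6, 8, 3, 3]))  # one edit over three characters -> 0.333...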
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 7a9d566..0e01bb5 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -13,6 +13,7 @@ from text_recognizer.models.base import BaseLitModel
@attr.s(auto_attribs=True)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
+
mapping: Type[AbstractMapping] = attr.ib()
start_token: str = attr.ib()
end_token: str = attr.ib()