Diffstat (limited to 'text_recognizer/networks/cnn_tranformer.py')
-rw-r--r--  text_recognizer/networks/cnn_tranformer.py  30
1 file changed, 17 insertions(+), 13 deletions(-)
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index e030cb8..ce7ec43 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -7,6 +7,7 @@ import torch
 from torch import nn, Tensor
 
 from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
@@ -15,7 +16,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
 
 
 @attr.s
-class CnnTransformer(nn.Module):
+class Reader(nn.Module):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
@@ -27,21 +28,20 @@ class CnnTransformer(nn.Module):
     num_classes: int = attr.ib()
     padding_idx: int = attr.ib()
     start_token: str = attr.ib()
-    start_index: int = attr.ib(init=False, default=None)
+    start_index: int = attr.ib(init=False)
     end_token: str = attr.ib()
-    end_index: int = attr.ib(init=False, default=None)
+    end_index: int = attr.ib(init=False)
     pad_token: str = attr.ib()
-    pad_index: int = attr.ib(init=False, default=None)
+    pad_index: int = attr.ib(init=False)
     # Modules.
-    encoder: Type[nn.Module] = attr.ib()
+    encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
-    embedding: nn.Embedding = attr.ib(init=False, default=None)
-    latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
-    token_embedding: nn.Embedding = attr.ib(init=False, default=None)
-    token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
-    head: nn.Linear = attr.ib(init=False, default=None)
-    mapping: AbstractMapping = attr.ib(init=False, default=None)
+    latent_encoder: nn.Sequential = attr.ib(init=False)
+    token_embedding: nn.Embedding = attr.ib(init=False)
+    token_pos_encoder: PositionalEncoding = attr.ib(init=False)
+    head: nn.Linear = attr.ib(init=False)
+    mapping: Type[AbstractMapping] = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
@@ -187,12 +187,16 @@ class CnnTransformer(nn.Module):
             output[:, i : i + 1] = tokens[-1:]
 
             # Early stopping of prediction loop if token is end or padding token.
-            if (output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index).all():
+            if (
+                (output[:, i - 1] == self.end_index) | (output[:, i - 1] == self.pad_index)
+            ).all():
                 break
 
         # Set all tokens after end token to pad token.
         for i in range(1, self.max_output_len):
-            idx = (output[:, i -1] == self.end_index | output[:, i - 1] == self.pad_index)
+            idx = (output[:, i - 1] == self.end_index) | (
+                output[:, i - 1] == self.pad_index
+            )
             output[idx, i] = self.pad_index
 
         return output
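
A minimal standalone sketch of the end/pad masking loop in the last hunk, using toy end_index and pad_index values (assumptions for illustration, not part of the commit). Note that the comparisons need explicit parentheses: in Python, | binds tighter than ==, so the unparenthesised form compares against end_index | tensor instead of OR-ing two boolean masks.

import torch

end_index, pad_index = 1, 0  # toy values; the real indices come from the mapping
max_output_len = 6
# Two toy decoded sequences; 9 stands in for an ordinary character token.
output = torch.tensor([[9, 9, 1, 9, 9, 9], [9, 1, 9, 9, 9, 9]])

# Set every token after an end (or pad) token to the pad token.
for i in range(1, max_output_len):
    idx = (output[:, i - 1] == end_index) | (output[:, i - 1] == pad_index)
    output[idx, i] = pad_index

print(output)
# tensor([[9, 9, 1, 0, 0, 0],
#         [9, 1, 0, 0, 0, 0]])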