summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-29 23:59:52 +0200
commit34098ccbbbf6379c0bd29a987440b8479c743746 (patch)
treea8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer/networks
parentc032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff)
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/base.py18
-rw-r--r--text_recognizer/networks/conv_transformer.py (renamed from text_recognizer/networks/cnn_tranformer.py)27
2 files changed, 31 insertions, 14 deletions
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
new file mode 100644
index 0000000..07b6a32
--- /dev/null
+++ b/text_recognizer/networks/base.py
@@ -0,0 +1,18 @@
+"""Base network with required methods."""
+from abc import abstractmethod
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class BaseNetwork(nn.Module):
+ """Base network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ @abstractmethod
+ def predict(self, x: Tensor) -> Tensor:
+ """Return token indices for predictions."""
+ ...
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/conv_transformer.py
index ce7ec43..4acdc36 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -7,6 +7,7 @@ import torch
from torch import nn, Tensor
from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.base import BaseNetwork
from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
@@ -15,39 +16,37 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s
-class Reader(nn.Module):
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
+@attr.s(auto_attribs=True)
+class ConvTransformer(BaseNetwork):
# Parameters and placeholders,
input_dims: Tuple[int, int, int] = attr.ib()
hidden_dim: int = attr.ib()
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
- padding_idx: int = attr.ib()
start_token: str = attr.ib()
- start_index: int = attr.ib(init=False)
+ start_index: Tensor = attr.ib(init=False)
end_token: str = attr.ib()
- end_index: int = attr.ib(init=False)
+ end_index: Tensor = attr.ib(init=False)
pad_token: str = attr.ib()
- pad_index: int = attr.ib(init=False)
+ pad_index: Tensor = attr.ib(init=False)
# Modules.
encoder: EfficientNet = attr.ib()
decoder: Decoder = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
latent_encoder: nn.Sequential = attr.ib(init=False)
token_embedding: nn.Embedding = attr.ib(init=False)
token_pos_encoder: PositionalEncoding = attr.ib(init=False)
head: nn.Linear = attr.ib(init=False)
- mapping: Type[AbstractMapping] = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = int(self.mapping.get_index(self.start_token))
- self.end_index = int(self.mapping.get_index(self.end_token))
- self.pad_index = int(self.mapping.get_index(self.pad_token))
+ self.start_index = self.mapping.get_index(self.start_token)
+ self.end_index = self.mapping.get_index(self.end_token)
+ self.pad_index = self.mapping.get_index(self.pad_token)
+
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -130,7 +129,7 @@ class Reader(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
- context_mask = context != self.padding_idx
+ context_mask = context != self.pad_index
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
out = self.decoder(x=context, context=z, mask=context_mask)