diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-29 23:59:52 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-29 23:59:52 +0200 |
commit | 34098ccbbbf6379c0bd29a987440b8479c743746 (patch) | |
tree | a8c68e3036503049fc7034c677ec855465f7a8e0 /text_recognizer/networks | |
parent | c032ffb05a7ed86f8fe5d596f94e8997c558cae8 (diff) |
Configs, refactor with attrs, fix attr bug in iam
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/base.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/conv_transformer.py (renamed from text_recognizer/networks/cnn_tranformer.py) | 27 |
2 files changed, 31 insertions, 14 deletions
```diff
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
new file mode 100644
index 0000000..07b6a32
--- /dev/null
+++ b/text_recognizer/networks/base.py
@@ -0,0 +1,18 @@
+"""Base network with required methods."""
+from abc import abstractmethod
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class BaseNetwork(nn.Module):
+    """Base network."""
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    @abstractmethod
+    def predict(self, x: Tensor) -> Tensor:
+        """Return token indices for predictions."""
+        ...
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/conv_transformer.py
index ce7ec43..4acdc36 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -7,6 +7,7 @@ import torch
 from torch import nn, Tensor
 
 from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.base import BaseNetwork
 from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
@@ -15,39 +16,37 @@ from text_recognizer.networks.transformer.positional_encodings import (
 )
 
 
-@attr.s
-class Reader(nn.Module):
-    def __attrs_pre_init__(self) -> None:
-        super().__init__()
-
+@attr.s(auto_attribs=True)
+class ConvTransformer(BaseNetwork):
     # Parameters and placeholders,
     input_dims: Tuple[int, int, int] = attr.ib()
     hidden_dim: int = attr.ib()
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
-    padding_idx: int = attr.ib()
     start_token: str = attr.ib()
-    start_index: int = attr.ib(init=False)
+    start_index: Tensor = attr.ib(init=False)
     end_token: str = attr.ib()
-    end_index: int = attr.ib(init=False)
+    end_index: Tensor = attr.ib(init=False)
     pad_token: str = attr.ib()
-    pad_index: int = attr.ib(init=False)
+    pad_index: Tensor = attr.ib(init=False)
 
     # Modules.
     encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
+    mapping: Type[AbstractMapping] = attr.ib()
+
     latent_encoder: nn.Sequential = attr.ib(init=False)
     token_embedding: nn.Embedding = attr.ib(init=False)
     token_pos_encoder: PositionalEncoding = attr.ib(init=False)
     head: nn.Linear = attr.ib(init=False)
-    mapping: Type[AbstractMapping] = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = int(self.mapping.get_index(self.start_token))
-        self.end_index = int(self.mapping.get_index(self.end_token))
-        self.pad_index = int(self.mapping.get_index(self.pad_token))
+        self.start_index = self.mapping.get_index(self.start_token)
+        self.end_index = self.mapping.get_index(self.end_token)
+        self.pad_index = self.mapping.get_index(self.pad_token)
+
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -130,7 +129,7 @@ class Reader(nn.Module):
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        context_mask = context != self.padding_idx
+        context_mask = context != self.pad_index
         context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
         context = self.token_pos_encoder(context)
         out = self.decoder(x=context, context=z, mask=context_mask)
```