author    | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200
commit    | 7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch)
tree      | 8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer
parent    | 92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff)
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer')
22 files changed, 165 insertions, 207 deletions
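Most of the changes below replace hand-written `__init__` methods with `attrs`-managed classes and adjust how those classes interact with the PyTorch and PyTorch Lightning base classes. A minimal, self-contained sketch of the three hooks involved is shown here; the class and field names are hypothetical and only illustrate the pattern, they are not code from this commit.

```python
import attr
from torch import nn


@attr.s(auto_attribs=True, repr=False)  # repr=False: keep nn.Module's own __repr__
class ToyNetwork(nn.Module):
    """Hypothetical module illustrating the attrs hooks used throughout the diff."""

    hidden_dim: int = attr.ib(default=64)
    fc: nn.Linear = attr.ib(init=False)

    def __attrs_pre_init__(self) -> None:
        # nn.Module.__init__ has to run before attrs assigns any fields,
        # otherwise registering the submodule below would fail.
        super().__init__()

    def __attrs_post_init__(self) -> None:
        # Fields marked init=False are configured after attrs has set the rest.
        self.fc = nn.Linear(self.hidden_dim, self.hidden_dim)


net = ToyNetwork(hidden_dim=32)
print(net)  # nn.Module's repr, since attrs did not generate one
```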
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
index 3599a8b..2727b20 100644
--- a/text_recognizer/data/__init__.py
+++ b/text_recognizer/data/__init__.py
@@ -1,7 +1 @@
 """Dataset modules."""
-from .base_dataset import BaseDataset, convert_strings_to_labels, split_dataset
-from .base_data_module import BaseDataModule, load_and_print_info
-from .download_utils import download_dataset
-from .iam_paragraphs import IAMParagraphs
-from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
-from .iam_extended_paragraphs import IAMExtendedParagraphs
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 18b1996..408ae36 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -17,7 +17,7 @@ def load_and_print_info(data_module_class: type) -> None:
     print(dataset)
 
 
-@attr.s
+@attr.s(repr=False)
 class BaseDataModule(LightningDataModule):
     """Base PyTorch Lightning DataModule."""
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 4747508..7548ad5 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -32,7 +32,7 @@ IMAGE_X_PADDING = 28
 MAX_OUTPUT_LENGTH = 89  # Same as IAMLines
 
 
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class EMNISTLines(BaseDataModule):
     """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 58c7369..23e424d 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -10,7 +10,7 @@ from text_recognizer.data.iam_paragraphs import IAMParagraphs
 from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
 
 
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMExtendedParagraphs(BaseDataModule):
 
     augment: bool = attr.ib(default=True)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index 13dd379..b7f3fdd 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -37,7 +37,7 @@ IMAGE_WIDTH = 1024
 MAX_LABEL_LENGTH = 89
 
 
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMLines(BaseDataModule):
     """IAM handwritten lines dataset."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index de32875..82058e0 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -34,7 +34,7 @@ IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
 MAX_LABEL_LENGTH = 682
 
 
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMParagraphs(BaseDataModule):
     """IAM handwriting database paragraphs."""
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index f7457e4..93a13bb 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -9,7 +9,7 @@ import collections
 import itertools
 from pathlib import Path
 import re
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Sequence
 
 import click
 from loguru import logger
@@ -57,15 +57,13 @@ class Preprocessor:
         lexicon_path: Optional[Union[str, Path]] = None,
         use_words: bool = False,
         prepend_wordsep: bool = False,
-        special_tokens: Optional[List[str]] = None,
+        special_tokens: Optional[Sequence[str]] = None,
     ) -> None:
         self.wordsep = "▁"
         self._use_word = use_words
         self._prepend_wordsep = prepend_wordsep
         self.special_tokens = special_tokens if special_tokens is not None else None
-
         self.data_dir = Path(data_dir)
-
         self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
 
         # Load the set of graphemes:
@@ -123,7 +121,7 @@ class Preprocessor:
             self.text.append(example["text"].lower())
 
     def _to_index(self, line: str) -> torch.LongTensor:
-        if line in self.special_tokens:
+        if self.special_tokens is not None and line in self.special_tokens:
             return torch.LongTensor([self.tokens_to_index[line]])
         token_to_index = self.graphemes_to_index
         if self.lexicon is not None:
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index a3697e7..f00a494 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -34,7 +34,7 @@ PROCESSED_DATA_DIRNAME = (
 )
 
 
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, repr=False)
 class IAMSyntheticParagraphs(IAMParagraphs):
     """IAM Handwriting database of synthetic paragraphs."""
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 0d778b2..a934fd9 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,8 +1,9 @@
 """Mapping to and from word pieces."""
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import List, Optional, Union, Sequence
+from typing import Dict, List, Optional, Union, Set, Sequence
 
+import attr
 from loguru import logger
 import torch
 from torch import Tensor
@@ -29,10 +30,17 @@ class AbstractMapping(ABC):
         ...
 
 
+@attr.s
 class EmnistMapping(AbstractMapping):
-    def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None:
+    extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
+    mapping: Sequence[str] = attr.ib(init=False)
+    inverse_mapping: Dict[str, int] = attr.ib(init=False)
+    input_size: List[int] = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
-            extra_symbols
+            self.extra_symbols
         )
 
     def get_token(self, index: Union[int, Tensor]) -> str:
@@ -54,42 +62,21 @@ class EmnistMapping(AbstractMapping):
         return Tensor([self.inverse_mapping[token] for token in text])
 
 
+@attr.s(auto_attribs=True)
 class WordPieceMapping(EmnistMapping):
-    def __init__(
-        self,
-        num_features: int = 1000,
-        tokens: str = "iamdb_1kwp_tokens_1000.txt",
-        lexicon: str = "iamdb_1kwp_lex_1000.txt",
-        data_dir: Optional[Union[str, Path]] = None,
-        use_words: bool = False,
-        prepend_wordsep: bool = False,
-        special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = ("\n",),
-    ) -> None:
-        super().__init__(extra_symbols)
-        self.wordpiece_processor = self._configure_wordpiece_processor(
-            num_features,
-            tokens,
-            lexicon,
-            data_dir,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
-            extra_symbols,
-        )
-
-    @staticmethod
-    def _configure_wordpiece_processor(
-        num_features: int,
-        tokens: str,
-        lexicon: str,
-        data_dir: Optional[Union[str, Path]],
-        use_words: bool,
-        prepend_wordsep: bool,
-        special_tokens: Optional[Sequence[str]],
-        extra_symbols: Optional[Sequence[str]],
-    ) -> Preprocessor:
-        data_dir = (
+    data_dir: Optional[Path] = attr.ib(default=None)
+    num_features: int = attr.ib(default=1000)
+    tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
+    lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
+    use_words: bool = attr.ib(default=False)
+    prepend_wordsep: bool = attr.ib(default=False)
+    special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
+    extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
+    wordpiece_processor: Preprocessor = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        super().__attrs_post_init__()
+        self.data_dir = (
             (
                 Path(__file__).resolve().parents[2]
                 / "data"
@@ -97,32 +84,32 @@ class WordPieceMapping(EmnistMapping):
                 / "iam"
                 / "iamdb"
             )
-            if data_dir is None
-            else Path(data_dir)
+            if self.data_dir is None
+            else Path(self.data_dir)
         )
-
-        logger.debug(f"Using data dir: {data_dir}")
-        if not data_dir.exists():
-            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+        logger.debug(f"Using data dir: {self.data_dir}")
+        if not self.data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
 
         processed_path = (
             Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
         )
-        tokens_path = processed_path / tokens
-        lexicon_path = processed_path / lexicon
-
-        if extra_symbols is not None:
-            special_tokens += extra_symbols
-
-        return Preprocessor(
-            data_dir,
-            num_features,
-            tokens_path,
-            lexicon_path,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
+        tokens_path = processed_path / self.tokens
+        lexicon_path = processed_path / self.lexicon
+
+        special_tokens = self.special_tokens
+        if self.extra_symbols is not None:
+            special_tokens = special_tokens | self.extra_symbols
+
+        self.wordpiece_processor = Preprocessor(
+            data_dir=self.data_dir,
+            num_features=self.num_features,
+            tokens_path=tokens_path,
+            lexicon_path=lexicon_path,
+            use_words=self.use_words,
+            prepend_wordsep=self.prepend_wordsep,
+            special_tokens=special_tokens,
         )
 
     def __len__(self) -> int:
@@ -151,7 +138,9 @@ class WordPieceMapping(EmnistMapping):
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
 
-    def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+    def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+        if isinstance(x, int):
+            x = [x]
         if isinstance(x, str):
-            return self.get_index(x)
-        return self.get_token(x)
+            return self.get_indices(x)
+        return self.get_text(x)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 66531a5..3b1b929 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -24,14 +24,14 @@ class WordPiece:
         max_len: int = 451,
     ) -> None:
         self.mapping = WordPieceMapping(
-            num_features,
-            tokens,
-            lexicon,
-            data_dir,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
-            extra_symbols,
+            data_dir=data_dir,
+            num_features=num_features,
+            tokens=tokens,
+            lexicon=lexicon,
+            use_words=use_words,
+            prepend_wordsep=prepend_wordsep,
+            special_tokens=special_tokens,
+            extra_symbols=extra_symbols,
         )
         self.max_len = max_len
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3e02261..dfb4ca4 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,8 +11,6 @@ from torch import nn
 from torch import Tensor
 import torchmetrics
 
-from text_recognizer.networks.base import BaseNetwork
-
 
 @attr.s
 class BaseLitModel(LightningModule):
@@ -21,7 +19,7 @@ class BaseLitModel(LightningModule):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
-    network: Type[BaseNetwork] = attr.ib()
+    network: Type[nn.Module] = attr.ib()
     criterion_config: DictConfig = attr.ib(converter=DictConfig)
     optimizer_config: DictConfig = attr.ib(converter=DictConfig)
     lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 4117ae2..9793157 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,5 +1,5 @@
 """Character Error Rate (CER)."""
-from typing import Set, Sequence
+from typing import Set
 
 import attr
 import editdistance
@@ -12,7 +12,7 @@ from torchmetrics import Metric
 class CharacterErrorRate(Metric):
     """Character error rate metric, computed using Levenshtein distance."""
 
-    ignore_tokens: Set = attr.ib(converter=set)
+    ignore_indices: Set = attr.ib(converter=set)
     error: Tensor = attr.ib(init=False)
     total: Tensor = attr.ib(init=False)
@@ -25,8 +25,8 @@ class CharacterErrorRate(Metric):
         """Update CER."""
         bsz = preds.shape[0]
         for index in range(bsz):
-            pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
-            target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+            pred = [p for p in preds[index].tolist() if p not in self.ignore_indices]
+            target = [t for t in targets[index].tolist() if t not in self.ignore_indices]
             distance = editdistance.distance(pred, target)
             error = distance / max(len(pred), len(target))
             self.error += error
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f5cb491..7a9d566 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,11 +1,11 @@
 """PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
+from typing import Sequence, Tuple, Type
 
 import attr
-import hydra
-from omegaconf import DictConfig
-from torch import nn, Tensor
+import torch
+from torch import Tensor
 
+from text_recognizer.data.mappings import AbstractMapping
 from text_recognizer.models.metrics import CharacterErrorRate
 from text_recognizer.models.base import BaseLitModel
 
@@ -13,18 +13,31 @@ from text_recognizer.models.base import BaseLitModel
 @attr.s(auto_attribs=True)
 class TransformerLitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
+    mapping: Type[AbstractMapping] = attr.ib()
+    start_token: str = attr.ib()
+    end_token: str = attr.ib()
+    pad_token: str = attr.ib()
 
-    ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+    start_index: Tensor = attr.ib(init=False)
+    end_index: Tensor = attr.ib(init=False)
+    pad_index: Tensor = attr.ib(init=False)
+
+    ignore_indices: Sequence[str] = attr.ib(init=False)
     val_cer: CharacterErrorRate = attr.ib(init=False)
     test_cer: CharacterErrorRate = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
-        self.val_cer = CharacterErrorRate(self.ignore_tokens)
-        self.test_cer = CharacterErrorRate(self.ignore_tokens)
+        """Post init configuration."""
+        self.start_index = self.mapping.get_index(self.start_token)
+        self.end_index = self.mapping.get_index(self.end_token)
+        self.pad_index = self.mapping.get_index(self.pad_token)
+        self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+        self.val_cer = CharacterErrorRate(self.ignore_indices)
+        self.test_cer = CharacterErrorRate(self.ignore_indices)
 
     def forward(self, data: Tensor) -> Tensor:
         """Forward pass with the transformer network."""
-        return self.network.predict(data)
+        return self.predict(data)
 
     def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
         """Training step."""
@@ -38,17 +51,64 @@ class TransformerLitModel(BaseLitModel):
         """Validation step."""
         data, targets = batch
 
+        # Compute the loss.
         logits = self.network(data, targets[:-1])
         loss = self.loss_fn(logits, targets[1:])
         self.log("val/loss", loss, prog_bar=True)
 
-        pred = self.network.predict(data)
+        # Get the token prediction.
+        pred = self(data)
         self.val_cer(pred, targets)
         self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
         data, targets = batch
-        pred = self.network.predict(data)
+
+        # Compute the text prediction.
+        pred = self(data)
         self.test_cer(pred, targets)
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+    def predict(self, x: Tensor) -> Tensor:
+        """Predicts text in image.
+
+        Args:
+            x (Tensor): Image(s) to extract text from.
+
+        Shapes:
+            - x: :math: `(B, H, W)`
+            - output: :math: `(B, S)`
+
+        Returns:
+            Tensor: A tensor of token indices of the predictions from the model.
+        """
+        bsz = x.shape[0]
+
+        # Encode image(s) to latent vectors.
+        z = self.network.encode(x)
+
+        # Create a placeholder matrix for storing outputs from the network
+        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+        output[:, 0] = self.start_index
+
+        for i in range(1, self.max_output_len):
+            context = output[:, :i]  # (bsz, i)
+            logits = self.network.decode(z, context)  # (i, bsz, c)
+            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
+            output[:, i : i + 1] = tokens[-1:]
+
+            # Early stopping of prediction loop if token is end or padding token.
+            if (
+                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
+            ).all():
+                break
+
+        # Set all tokens after end token to pad token.
+        for i in range(1, self.max_output_len):
+            idx = (
+                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
+            )
+            output[idx, i] = self.pad_index
+
+        return output
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 0172163..e215e14 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -34,8 +34,6 @@ class VQVAELitModel(BaseLitModel):
         loss = self.loss_fn(reconstructions, data)
         loss += vq_loss
         self.log("val/loss", loss, prog_bar=True)
-        title = "val_pred_examples"
-        self._log_prediction(data, reconstructions, title)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -43,5 +41,4 @@ class VQVAELitModel(BaseLitModel):
         reconstructions, vq_loss = self.network(data)
         loss = self.loss_fn(reconstructions, data)
         loss += vq_loss
-        title = "test_pred_examples"
-        self._log_prediction(data, reconstructions, title)
+        self.log("test/loss", loss)
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 618450f..d9ef58b 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,5 +1 @@
 """Network modules"""
-# from .encoders import EfficientNet
-from .vqvae import VQVAE
-
-# from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
deleted file mode 100644
index 07b6a32..0000000
--- a/text_recognizer/networks/base.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Base network with required methods."""
-from abc import abstractmethod
-
-import attr
-from torch import nn, Tensor
-
-
-@attr.s
-class BaseNetwork(nn.Module):
-    """Base network."""
-
-    def __attrs_pre_init__(self) -> None:
-        super().__init__()
-
-    @abstractmethod
-    def predict(self, x: Tensor) -> Tensor:
-        """Return token indices for predictions."""
-        ...
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 4acdc36..7371be4 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,13 +1,10 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Tuple, Type
+from typing import Tuple
 
 import attr
-import torch
 from torch import nn, Tensor
 
-from text_recognizer.data.mappings import AbstractMapping
-from text_recognizer.networks.base import BaseNetwork
 from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
@@ -16,25 +13,24 @@
 )
 
 
-@attr.s(auto_attribs=True)
-class ConvTransformer(BaseNetwork):
+@attr.s
+class ConvTransformer(nn.Module):
+    """Convolutional encoder and transformer decoder network."""
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
     # Parameters and placeholders,
     input_dims: Tuple[int, int, int] = attr.ib()
     hidden_dim: int = attr.ib()
     dropout_rate: float = attr.ib()
     max_output_len: int = attr.ib()
     num_classes: int = attr.ib()
-    start_token: str = attr.ib()
-    start_index: Tensor = attr.ib(init=False)
-    end_token: str = attr.ib()
-    end_index: Tensor = attr.ib(init=False)
-    pad_token: str = attr.ib()
-    pad_index: Tensor = attr.ib(init=False)
+    pad_index: Tensor = attr.ib()
 
     # Modules.
     encoder: EfficientNet = attr.ib()
     decoder: Decoder = attr.ib()
-    mapping: Type[AbstractMapping] = attr.ib()
 
     latent_encoder: nn.Sequential = attr.ib(init=False)
     token_embedding: nn.Embedding = attr.ib(init=False)
@@ -43,10 +39,6 @@ class ConvTransformer(BaseNetwork):
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.start_index = self.mapping.get_index(self.start_token)
-        self.end_index = self.mapping.get_index(self.end_token)
-        self.pad_index = self.mapping.get_index(self.pad_token)
-
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -156,46 +148,3 @@ class ConvTransformer(BaseNetwork):
         z = self.encode(x)
         logits = self.decode(z, context)
         return logits
-
-    def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image.
-
-        Args:
-            x (Tensor): Image(s) to extract text from.
-
-        Shapes:
-            - x: :math: `(B, H, W)`
-            - output: :math: `(B, S)`
-
-        Returns:
-            Tensor: A tensor of token indices of the predictions from the model.
-        """
-        bsz = x.shape[0]
-
-        # Encode image(s) to latent vectors.
-        z = self.encode(x)
-
-        # Create a placeholder matrix for storing outputs from the network
-        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        output[:, 0] = self.start_index
-
-        for i in range(1, self.max_output_len):
-            context = output[:, :i]  # (bsz, i)
-            logits = self.decode(z, context)  # (i, bsz, c)
-            tokens = torch.argmax(logits, dim=-1)  # (i, bsz)
-            output[:, i : i + 1] = tokens[-1:]
-
-            # Early stopping of prediction loop if token is end or padding token.
-            if (
-                output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index
-            ).all():
-                break
-
-        # Set all tokens after end token to pad token.
-        for i in range(1, self.max_output_len):
-            idx = (
-                output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index
-            )
-            output[idx, i] = self.pad_index
-
-        return output
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 2770dc1..9202cce 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -24,9 +24,9 @@ class Attention(nn.Module):
 
     dim: int = attr.ib()
     num_heads: int = attr.ib()
+    causal: bool = attr.ib(default=False)
     dim_head: int = attr.ib(default=64)
     dropout_rate: float = attr.ib(default=0.0)
-    casual: bool = attr.ib(default=False)
     scale: float = attr.ib(init=False)
     dropout: nn.Dropout = attr.ib(init=False)
     fc: nn.Linear = attr.ib(init=False)
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 9b2f236..66c9c50 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -30,8 +30,7 @@ class AttentionLayers(nn.Module):
     causal: bool = attr.ib(default=False)
     cross_attend: bool = attr.ib(default=False)
     pre_norm: bool = attr.ib(default=True)
-    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None, init=False)
-    has_pos_emb: bool = attr.ib(init=False)
+    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
     attn: partial = attr.ib(init=False)
@@ -40,12 +39,11 @@ class AttentionLayers(nn.Module):
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        self.has_pos_emb = True if self.rotary_emb is not None else False
         self.layer_types = self._get_layer_types() * self.depth
         attn = load_partial_fn(
             self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
         )
-        norm = load_partial_fn(self.norm_fn, dim=self.dim)
+        norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
         ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
         self.layers = self._build_network(attn, norm, ff)
@@ -103,13 +101,11 @@ class AttentionLayers(nn.Module):
         return x
 
 
+@attr.s(auto_attribs=True)
 class Encoder(AttentionLayers):
-    def __init__(self, **kwargs: Any) -> None:
-        assert "causal" not in kwargs, "Cannot set causality on encoder"
-        super().__init__(causal=False, **kwargs)
+    causal: bool = attr.ib(default=False, init=False)
 
 
+@attr.s(auto_attribs=True)
 class Decoder(AttentionLayers):
-    def __init__(self, **kwargs: Any) -> None:
-        assert "causal" not in kwargs, "Cannot set causality on decoder"
-        super().__init__(causal=True, **kwargs)
+    causal: bool = attr.ib(default=True, init=False)
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 8bc3221..4930adf 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -12,9 +12,9 @@ from torch import Tensor
 
 
 class ScaleNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+    def __init__(self, normalized_shape: int, eps: float = 1.0e-5) -> None:
         super().__init__()
-        self.scale = dim ** -0.5
+        self.scale = normalized_shape ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(1))
@@ -24,9 +24,9 @@
 class PreNorm(nn.Module):
-    def __init__(self, dim: int, fn: Type[nn.Module]) -> None:
+    def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
+        self.norm = nn.LayerNorm(normalized_shape)
         self.fn = fn
 
     def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index e822c57..c94e8dc 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -24,6 +24,6 @@ def activation_function(activation: str) -> Type[nn.Module]:
 
 
 def load_partial_fn(fn: str, **kwargs: Any) -> partial:
-    """Loads partial function."""
+    """Loads partial function/class."""
     module = import_module(".".join(fn.split(".")[:-1]))
-    return partial(getattr(module, fn.split(".")[0]), **kwargs)
+    return partial(getattr(module, fn.split(".")[-1]), **kwargs)
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1f08e5e..5aa929b 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,5 +1,4 @@
 """The VQ-VAE."""
-
 from typing import Any, Dict, List, Optional, Tuple
 
 from torch import nn
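The `predict` method that moves from `ConvTransformer` into `TransformerLitModel` performs greedy autoregressive decoding with early stopping. Below is a standalone sketch of that loop against a dummy `decode_step` callable; it is an illustration under assumed batch-first shapes rather than the network's actual interface, and the stop condition is written with explicit parentheses since `|` binds more tightly than `==` in Python.

```python
import torch
from torch import Tensor


def greedy_decode(
    decode_step,  # hypothetical callable: (B, i) context -> (B, i, C) logits
    start_index: int,
    end_index: int,
    pad_index: int,
    max_output_len: int,
    bsz: int,
) -> Tensor:
    """Greedy autoregressive decoding with early stopping (batch-first shapes)."""
    output = torch.full((bsz, max_output_len), pad_index, dtype=torch.long)
    output[:, 0] = start_index
    for i in range(1, max_output_len):
        context = output[:, :i]        # (B, i)
        logits = decode_step(context)  # (B, i, C)
        output[:, i] = logits.argmax(dim=-1)[:, -1]
        # Stop once every sequence has emitted an end or pad token.
        # Parentheses matter: `|` binds more tightly than `==` in Python.
        if ((output[:, i] == end_index) | (output[:, i] == pad_index)).all():
            break
    return output


# Toy decode_step that always predicts token 2 (the end token here).
def toy_step(context: Tensor) -> Tensor:
    return torch.nn.functional.one_hot(
        torch.full_like(context, 2), num_classes=5
    ).float()


print(greedy_decode(toy_step, start_index=1, end_index=2, pad_index=3,
                    max_output_len=6, bsz=2))
```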
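The `load_partial_fn` fix in `networks/util.py` resolves the attribute from the last component of the dotted path rather than the first. A small usage sketch follows; the `torch.nn.LayerNorm` target is only an illustrative example, not taken from this commit.

```python
from functools import partial
from importlib import import_module
from typing import Any


def load_partial_fn(fn: str, **kwargs: Any) -> partial:
    """Load a function or class from a dotted path and bind keyword arguments."""
    module = import_module(".".join(fn.split(".")[:-1]))
    # The attribute lives under the last path component, hence [-1] rather than [0].
    return partial(getattr(module, fn.split(".")[-1]), **kwargs)


# Example: pre-bind normalized_shape, instantiate later.
norm_factory = load_partial_fn("torch.nn.LayerNorm", normalized_shape=64)
norm = norm_factory()
print(norm)  # a LayerNorm over a 64-dimensional feature axis
```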