diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
commit | 4d1f2cef39688871d2caafce42a09316381a27ae (patch) | |
tree | 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/networks | |
parent | f0481decdad9afb52494e9e95996deef843ef233 (diff) |
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/cnn_tranformer.py | 14 | ||||
-rw-r--r-- | text_recognizer/networks/loss/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/loss/label_smoothing_loss.py | 42 | ||||
-rw-r--r-- | text_recognizer/networks/util.py | 7 |
4 files changed, 16 insertions, 49 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py new file mode 100644 index 0000000..da69311 --- /dev/null +++ b/text_recognizer/networks/cnn_tranformer.py @@ -0,0 +1,14 @@ +"""Vision transformer for character recognition.""" +from typing import Type + +import attr +from torch import nn, Tensor + + +@attr.s +class CnnTransformer(nn.Module): + def __attrs_pre_init__(self) -> None: + super().__init__() + + backbone: Type[nn.Module] = attr.ib() + head = Type[nn.Module] = attr.ib() diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py deleted file mode 100644 index cb83608..0000000 --- a/text_recognizer/networks/loss/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Loss module.""" -from .loss import LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py deleted file mode 100644 index 40a7609..0000000 --- a/text_recognizer/networks/loss/label_smoothing_loss.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class LabelSmoothingLoss(nn.Module): - """Label smoothing cross entropy loss.""" - - def __init__( - self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 - ) -> None: - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index - super().__init__() - - smoothing_value = label_smoothing / (vocab_size - 2) - one_hot = torch.full((vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - - self.confidence = 1.0 - label_smoothing - - def forward(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): Predictions from the network. - targets (Tensor): Ground truth. - - Shapes: - outpus: Batch size x num classes - targets: Batch size - - Returns: - Tensor: Label smoothing loss. - """ - model_prob = self.one_hot.repeat(targets.size(0), 1) - model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) - model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) - return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 05b10a8..109bf4d 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,10 +1,6 @@ """Miscellaneous neural network functionality.""" -import importlib -from pathlib import Path -from typing import Dict, NamedTuple, Union, Type +from typing import Type -from loguru import logger -import torch from torch import nn @@ -19,6 +15,7 @@ def activation_function(activation: str) -> Type[nn.Module]: ["none", nn.Identity()], ["relu", nn.ReLU(inplace=True)], ["selu", nn.SELU(inplace=True)], + ["mish", nn.Mish(inplace=True)], ] ) return activation_fns[activation.lower()] |