summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
commit4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/networks
parentf0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/cnn_tranformer.py14
-rw-r--r--text_recognizer/networks/loss/__init__.py2
-rw-r--r--text_recognizer/networks/loss/label_smoothing_loss.py42
-rw-r--r--text_recognizer/networks/util.py7
4 files changed, 16 insertions, 49 deletions
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
new file mode 100644
index 0000000..da69311
--- /dev/null
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -0,0 +1,14 @@
+"""Vision transformer for character recognition."""
+from typing import Type
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class CnnTransformer(nn.Module):
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ backbone: Type[nn.Module] = attr.ib()
+ head = Type[nn.Module] = attr.ib()
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
deleted file mode 100644
index cb83608..0000000
--- a/text_recognizer/networks/loss/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Loss module."""
-from .loss import LabelSmoothingCrossEntropy
diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py
deleted file mode 100644
index 40a7609..0000000
--- a/text_recognizer/networks/loss/label_smoothing_loss.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""Implementations of custom loss functions."""
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class LabelSmoothingLoss(nn.Module):
- """Label smoothing cross entropy loss."""
-
- def __init__(
- self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
- ) -> None:
- assert 0.0 < label_smoothing <= 1.0
- self.ignore_index = ignore_index
- super().__init__()
-
- smoothing_value = label_smoothing / (vocab_size - 2)
- one_hot = torch.full((vocab_size,), smoothing_value)
- one_hot[self.ignore_index] = 0
- self.register_buffer("one_hot", one_hot.unsqueeze(0))
-
- self.confidence = 1.0 - label_smoothing
-
- def forward(self, output: Tensor, targets: Tensor) -> Tensor:
- """Computes the loss.
-
- Args:
- output (Tensor): Predictions from the network.
- targets (Tensor): Ground truth.
-
- Shapes:
- outpus: Batch size x num classes
- targets: Batch size
-
- Returns:
- Tensor: Label smoothing loss.
- """
- model_prob = self.one_hot.repeat(targets.size(0), 1)
- model_prob.scatter_(1, targets.unsqueeze(1), self.confidence)
- model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0)
- return F.kl_div(output, model_prob, reduction="sum")
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 05b10a8..109bf4d 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,10 +1,6 @@
"""Miscellaneous neural network functionality."""
-import importlib
-from pathlib import Path
-from typing import Dict, NamedTuple, Union, Type
+from typing import Type
-from loguru import logger
-import torch
from torch import nn
@@ -19,6 +15,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
["none", nn.Identity()],
["relu", nn.ReLU(inplace=True)],
["selu", nn.SELU(inplace=True)],
+ ["mish", nn.Mish(inplace=True)],
]
)
return activation_fns[activation.lower()]