summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-01 23:53:50 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-01 23:53:50 +0200
commit58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (patch)
treec89c1b1a4cc1a499900f2700ab09e8535e2cfe99 /text_recognizer/networks
parent7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f (diff)
Working on new attention module
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/__init__.py2
-rw-r--r--text_recognizer/networks/encoders/__init__.py (renamed from text_recognizer/networks/backbones/__init__.py)0
-rw-r--r--text_recognizer/networks/encoders/efficientnet.py (renamed from text_recognizer/networks/backbones/efficientnet.py)0
-rw-r--r--text_recognizer/networks/encoders/residual_network.py (renamed from text_recognizer/networks/residual_network.py)0
-rw-r--r--text_recognizer/networks/encoders/wide_resnet.py (renamed from text_recognizer/networks/wide_resnet.py)0
-rw-r--r--text_recognizer/networks/transformer/attention.py119
-rw-r--r--text_recognizer/networks/transformer/norm.py13
7 files changed, 63 insertions, 71 deletions
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 63b43b2..a9117f8 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,4 +1,4 @@
"""Network modules"""
-from .backbones import EfficientNet
+from .encoders import EfficientNet
from .vqvae import VQVAE
from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/backbones/__init__.py b/text_recognizer/networks/encoders/__init__.py
index 25aed0e..25aed0e 100644
--- a/text_recognizer/networks/backbones/__init__.py
+++ b/text_recognizer/networks/encoders/__init__.py
diff --git a/text_recognizer/networks/backbones/efficientnet.py b/text_recognizer/networks/encoders/efficientnet.py
index 61dea77..61dea77 100644
--- a/text_recognizer/networks/backbones/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet.py
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/encoders/residual_network.py
index c33f419..c33f419 100644
--- a/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/encoders/residual_network.py
diff --git a/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/encoders/wide_resnet.py
index b767778..b767778 100644
--- a/text_recognizer/networks/wide_resnet.py
+++ b/text_recognizer/networks/encoders/wide_resnet.py
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index ac75d2f..e1324af 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,94 +1,73 @@
"""Implementes the attention module for the transformer."""
from typing import Optional, Tuple
-from einops import rearrange
+from einops.layers.torch import Rearrange
import numpy as np
import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
+from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb
-class MultiHeadAttention(nn.Module):
- """Implementation of multihead attention."""
+class Attention(nn.Module):
def __init__(
- self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+ self,
+ dim: int,
+ num_heads: int,
+ dim_head: int = 64,
+ dropout_rate: float = 0.0,
+ causal: bool = False,
) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
+ self.scale = dim ** -0.5
self.num_heads = num_heads
- self.fc_q = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_k = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_v = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
-
- self._init_weights()
+ self.causal = causal
+ inner_dim = dim * dim_head
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def _init_weights(self) -> None:
- nn.init.normal_(
- self.fc_q.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.normal_(
- self.fc_k.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ # Attnetion
+ self.qkv_fn = nn.Sequential(
+ nn.Linear(dim, 3 * inner_dim, bias=False),
+ Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
)
- nn.init.normal_(
- self.fc_v.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.xavier_normal_(self.fc_out.weight)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.attn_fn = F.softmax
- @staticmethod
- def scaled_dot_product_attention(
- query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
- ) -> Tensor:
- """Calculates the scaled dot product attention."""
+ # Feedforward
+ self.proj = nn.Linear(inner_dim, dim)
- # Compute the energy.
- energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
- query.shape[-1]
- )
-
- # If we have a mask for padding some inputs.
- if mask is not None:
- energy = energy.masked_fill(mask == 0, -np.inf)
-
- # Compute the attention from the energy.
- attention = torch.softmax(energy, dim=3)
+ @staticmethod
+ def _apply_rotary_emb(
+ q: Tensor, k: Tensor, rotary_pos_emb: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ l = rotary_pos_emb.shape[-1]
+ (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k))
+ ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb)
+ q = torch.cat((ql, qr), dim=-1)
+ k = torch.cat((kl, kr), dim=-1)
+ return q, k
- out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
- out = rearrange(out, "b head l v -> b l (head v)")
- return out, attention
+ def _cross_attention(self) -> Tensor:
+ pass
def forward(
- self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ self,
+ x: Tensor,
+ context: Optional[Tensor],
+ mask: Optional[Tensor],
+ context_mask: Optional[Tensor],
+ rotary_pos_emb: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
- """Forward pass for computing the multihead attention."""
- # Get the query, key, and value tensor.
- query = rearrange(
- self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
- )
- key = rearrange(
- self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
- )
- value = rearrange(
- self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+ q, k, v = self.qkv_fn(x)
+ q, k = (
+ self._apply_rotary_emb(q, k, rotary_pos_emb)
+ if rotary_pos_emb is not None
+ else q,
+ k,
)
- out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+ if any(x is not None for x in (mask, context_mask)):
+ pass
- out = self.fc_out(out)
- out = self.dropout(out)
- return out, attention
+ # Compute the attention
+ energy = (q @ k.transpose(-2, -1)) * self.scale
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 99a5291..9160876 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -20,3 +20,16 @@ class Rezero(nn.Module):
def forward(self, x: Tensor, **kwargs: Dict) -> Tensor:
x, *rest = self.fn(x, **kwargs)
return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1.0e-5) -> None:
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x: Tensor) -> Tensor:
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) self.g
+