path: root/text_recognizer/networks/transformer/attention.py
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--  text_recognizer/networks/transformer/attention.py  93
1 files changed, 93 insertions, 0 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+ """Implementation of multihead attention."""
+
+ def __init__(
+ self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.fc_q = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_k = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_v = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+ self._init_weights()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+    def _init_weights(self) -> None:
+        # Normal init for the query/key/value projections, scaled by the model
+        # dimension and the per-head dimension (Xavier-style), and Xavier init
+        # for the output projection.
+        std = np.sqrt(2 / (self.hidden_dim + self.hidden_dim // self.num_heads))
+        nn.init.normal_(self.fc_q.weight, mean=0, std=std)
+        nn.init.normal_(self.fc_k.weight, mean=0, std=std)
+        nn.init.normal_(self.fc_v.weight, mean=0, std=std)
+        nn.init.xavier_normal_(self.fc_out.weight)
+
+    def scaled_dot_product_attention(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Calculates the scaled dot product attention."""
+
+        # Compute the energy, i.e. the scaled dot product of queries and keys.
+        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+            query.shape[-1]
+        )
+
+        # If we have a mask for padding some inputs.
+        if mask is not None:
+            energy = energy.masked_fill(mask == 0, -np.inf)
+
+        # Compute the attention weights from the energy.
+        attention = torch.softmax(energy, dim=-1)
+
+        # Weight the values by the attention and merge the heads.
+        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+        out = rearrange(out, "b head l v -> b l (head v)")
+        return out, attention
+
+    def forward(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Forward pass for computing the multihead attention."""
+        # Project the query, key, and value tensors and split them into heads.
+        query = rearrange(
+            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+        )
+        key = rearrange(
+            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+        )
+        value = rearrange(
+            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+        )
+
+        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+        out = self.fc_out(out)
+        out = self.dropout(out)
+        return out, attention
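
A minimal usage sketch of the module above; the batch size, sequence length, hidden_dim, and the broadcastable padding-mask shape are example assumptions, not values prescribed by the module itself:

import torch
from text_recognizer.networks.transformer.attention import MultiHeadAttention

attention = MultiHeadAttention(hidden_dim=256, num_heads=8, dropout_rate=0.1)
x = torch.randn(4, 32, 256)  # (batch, sequence length, hidden_dim)
# Example padding mask, broadcast over heads and query positions: (batch, 1, 1, key length).
mask = torch.ones(4, 1, 1, 32, dtype=torch.bool)
out, attn = attention(x, x, x, mask=mask)  # self-attention: query = key = value
print(out.shape)   # torch.Size([4, 32, 256])
print(attn.shape)  # torch.Size([4, 8, 32, 32])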