diff options
Diffstat (limited to 'src/text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | src/text_recognizer/networks/transformer/attention.py | 93 |
1 files changed, 93 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py new file mode 100644 index 0000000..cce1ecc --- /dev/null +++ b/src/text_recognizer/networks/transformer/attention.py @@ -0,0 +1,93 @@ +"""Implementes the attention module for the transformer.""" +from typing import Optional, Tuple + +from einops import rearrange +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class MultiHeadAttention(nn.Module): + """Implementation of multihead attention.""" + + def __init__( + self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.fc_q = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_k = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_v = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) + + self._init_weights() + + self.dropout = nn.Dropout(p=dropout_rate) + + def _init_weights(self) -> None: + nn.init.normal_( + self.fc_q.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_k.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_v.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.xavier_normal_(self.fc_out.weight) + + def scaled_dot_product_attention( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tensor: + """Calculates the scaled dot product attention.""" + + # Compute the energy. + energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( + query.shape[-1] + ) + + # If we have a mask for padding some inputs. + if mask is not None: + energy = energy.masked_fill(mask == 0, -np.inf) + + # Compute the attention from the energy. + attention = torch.softmax(energy, dim=3) + + out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) + out = rearrange(out, "b head l v -> b l (head v)") + return out, attention + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + """Forward pass for computing the multihead attention.""" + # Get the query, key, and value tensor. + query = rearrange( + self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads + ) + key = rearrange( + self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads + ) + value = rearrange( + self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads + ) + + out, attention = self.scaled_dot_product_attention(query, key, value, mask) + + out = self.fc_out(out) + out = self.dropout(out) + return out, attention |