diff options
Diffstat (limited to 'src/text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | src/text_recognizer/networks/transformer/attention.py | 93 |
1 files changed, 0 insertions, 93 deletions
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py deleted file mode 100644 index cce1ecc..0000000 --- a/src/text_recognizer/networks/transformer/attention.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Implementes the attention module for the transformer.""" -from typing import Optional, Tuple - -from einops import rearrange -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class MultiHeadAttention(nn.Module): - """Implementation of multihead attention.""" - - def __init__( - self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.num_heads = num_heads - self.fc_q = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_k = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_v = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) - - self._init_weights() - - self.dropout = nn.Dropout(p=dropout_rate) - - def _init_weights(self) -> None: - nn.init.normal_( - self.fc_q.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_k.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_v.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.xavier_normal_(self.fc_out.weight) - - def scaled_dot_product_attention( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tensor: - """Calculates the scaled dot product attention.""" - - # Compute the energy. - energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( - query.shape[-1] - ) - - # If we have a mask for padding some inputs. - if mask is not None: - energy = energy.masked_fill(mask == 0, -np.inf) - - # Compute the attention from the energy. - attention = torch.softmax(energy, dim=3) - - out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) - out = rearrange(out, "b head l v -> b l (head v)") - return out, attention - - def forward( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor]: - """Forward pass for computing the multihead attention.""" - # Get the query, key, and value tensor. - query = rearrange( - self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads - ) - key = rearrange( - self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads - ) - value = rearrange( - self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads - ) - - out, attention = self.scaled_dot_product_attention(query, key, value, mask) - - out = self.fc_out(out) - out = self.dropout(out) - return out, attention |