path: root/src/text_recognizer/networks/transformer/attention.py
Diffstat (limited to 'src/text_recognizer/networks/transformer/attention.py')
-rw-r--r--    src/text_recognizer/networks/transformer/attention.py    93
1 file changed, 0 insertions, 93 deletions
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
deleted file mode 100644
index cce1ecc..0000000
--- a/src/text_recognizer/networks/transformer/attention.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""Implementes the attention module for the transformer."""
-from typing import Optional, Tuple
-
-from einops import rearrange
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class MultiHeadAttention(nn.Module):
- """Implementation of multihead attention."""
-
- def __init__(
- self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.num_heads = num_heads
- self.fc_q = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_k = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_v = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
-
- self._init_weights()
-
- self.dropout = nn.Dropout(p=dropout_rate)
-
-    def _init_weights(self) -> None:
-        nn.init.normal_(
-            self.fc_q.weight,
-            mean=0,
-            std=np.sqrt(2 / (self.hidden_dim + self.hidden_dim // self.num_heads)),
-        )
-        nn.init.normal_(
-            self.fc_k.weight,
-            mean=0,
-            std=np.sqrt(2 / (self.hidden_dim + self.hidden_dim // self.num_heads)),
-        )
-        nn.init.normal_(
-            self.fc_v.weight,
-            mean=0,
-            std=np.sqrt(2 / (self.hidden_dim + self.hidden_dim // self.num_heads)),
-        )
-        nn.init.xavier_normal_(self.fc_out.weight)
-
-    def scaled_dot_product_attention(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
-    ) -> Tuple[Tensor, Tensor]:
-        """Calculates the scaled dot product attention."""
-
-        # Compute the energy.
-        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
-            query.shape[-1]
-        )
-
-        # If we have a mask for padding some inputs.
-        if mask is not None:
-            energy = energy.masked_fill(mask == 0, -np.inf)
-
-        # Compute the attention from the energy.
-        attention = torch.softmax(energy, dim=3)
-
-        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
-        out = rearrange(out, "b head l v -> b l (head v)")
-        return out, attention
-
-    def forward(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
-    ) -> Tuple[Tensor, Tensor]:
-        """Forward pass for computing the multihead attention."""
-        # Get the query, key, and value tensor.
-        query = rearrange(
-            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
-        )
-        key = rearrange(
-            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
-        )
-        value = rearrange(
-            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
-        )
-
-        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
-
-        out = self.fc_out(out)
-        out = self.dropout(out)
-        return out, attention
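
For context, the removed module implements standard multi-head attention: each head computes softmax(Q.K^T / sqrt(d_k)).V over the projected query, key, and value tensors, and the heads are then concatenated and projected through fc_out. The sketch below is a minimal usage example of the class as it existed before this deletion; the hidden size, head count, dropout rate, batch shape, and mask are illustrative assumptions, and the import path assumes the src/ layout is on the Python path.

import torch

from text_recognizer.networks.transformer.attention import MultiHeadAttention

# Illustrative sizes; hidden_dim must be divisible by num_heads
# because of the "(head k)" rearrange in forward().
hidden_dim, num_heads = 64, 8
attention = MultiHeadAttention(hidden_dim=hidden_dim, num_heads=num_heads, dropout_rate=0.1)

# Self-attention over a batch of 2 sequences of length 10.
x = torch.randn(2, 10, hidden_dim)

# Optional padding mask, broadcastable to the (batch, head, query, key) energy;
# positions where mask == 0 are excluded from the softmax.
mask = torch.ones(2, 1, 1, 10)

out, attn = attention(x, x, x, mask)
print(out.shape)   # torch.Size([2, 10, 64])
print(attn.shape)  # torch.Size([2, 8, 10, 10])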