summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/attention_layers.py
blob: 721fa273fc5e55212a59ed06a2dfc917f36ed99d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Generates the attention layer architecture."""
from typing import Type

import torch
from torch import nn, Tensor


class AttentionLayers(nn.Module):
    """Stack of attention layers (placeholder; the layers are not built yet).

    NOTE(review): this module currently only records its constructor
    arguments; it builds no submodules and defines no ``forward``. The
    parameter descriptions below are inferred from the argument names and
    should be confirmed once the implementation lands.

    Args:
        dim: Feature dimension of the stack (presumably the model/embedding
            size of the inputs).
        depth: Presumably the number of layers in the stack.
        num_heads: Presumably the number of attention heads per layer.
        norm_layer: A normalization module *class* (not an instance), e.g.
            ``nn.LayerNorm``; presumably instantiated once per layer.
        causal: Presumably enables causal (autoregressive) masking.
        cross_attend: Presumably adds cross-attention over a context input.
        only_cross: Presumably uses cross-attention without self-attention.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        norm_layer: Type[nn.Module],
        causal: bool = False,
        cross_attend: bool = False,
        only_cross: bool = False,
    ) -> None:
        # nn.Module.__init__ must run before anything else is assigned: it
        # creates the internal registries (_parameters, _modules, _buffers, ...).
        # Skipping it makes parameters(), repr(), .to() and submodule
        # assignment fail with AttributeError.
        super().__init__()
        # Keep the configuration so a later implementation (and callers) can
        # read it instead of the arguments being silently discarded.
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.norm_layer = norm_layer
        self.causal = causal
        self.cross_attend = cross_attend
        self.only_cross = only_cross