summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/cnn_tranformer.py
blob: da69311f58db81b58c22d2e1f2c07f13c84d3114 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
"""Vision transformer for character recognition."""
from typing import Type

import attr
from torch import nn, Tensor


# eq=False keeps nn.Module's identity-based __eq__/__hash__. With attrs'
# default eq=True, attrs sets __hash__ = None, and the module becomes
# unhashable. That breaks nn.Module helpers such as modules()/parameters(),
# which track visited modules in a set.
@attr.s(eq=False)
class CnnTransformer(nn.Module):
    """Network composed of a ``backbone`` sub-network and a ``head`` sub-network.

    Both sub-networks are passed to the attrs-generated ``__init__`` and are
    registered as child modules of this ``nn.Module``.

    NOTE(review): no ``forward`` is defined in this class yet; presumably the
    backbone output is meant to feed the head — confirm when it is added.
    """

    def __attrs_pre_init__(self) -> None:
        # nn.Module.__init__ must run before the attrs-generated __init__
        # assigns the fields below: nn.Module.__setattr__ needs the internal
        # registries (e.g. _modules) that nn.Module.__init__ creates.
        super().__init__()

    # Backbone sub-network; expected to be an nn.Module instance so that it is
    # registered as a child module (a class object would not be).
    backbone: nn.Module = attr.ib()
    # Head sub-network; expected to be an nn.Module instance for the same reason.
    head: nn.Module = attr.ib()