summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/base.py
blob: 07b6a328155897e7bda015ba40c63e284d1826aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""Base network with required methods."""
from abc import abstractmethod

import attr
from torch import nn, Tensor


@attr.s(eq=False)
class BaseNetwork(nn.Module):
    """Base network for models defined as ``attrs`` classes on ``nn.Module``.

    Subclasses declare their configuration as ``attrs`` fields and are
    expected to implement :meth:`predict`.

    ``eq=False`` is deliberate. With the ``attrs`` default (``eq=True``) the
    decorator sets ``__hash__ = None`` on the class, and subclasses inherit
    it. That makes every network unhashable, which breaks ``nn.Module``
    machinery that keeps modules in sets (e.g. ``named_modules()``, and
    therefore ``modules()`` / ``parameters()``). With ``eq=False`` the
    identity-based ``__eq__`` / ``__hash__`` inherited from ``nn.Module``
    are kept.
    """

    def __attrs_pre_init__(self) -> None:
        # attrs invokes this hook before assigning any fields. nn.Module's
        # own __init__ must run first so that its internal registries exist
        # when fields holding submodules/parameters are assigned.
        super().__init__()

    @abstractmethod
    def predict(self, x: Tensor) -> Tensor:
        """Return token indices for predictions.

        Args:
            x: Input tensor; its expected shape is defined by the subclass
                (not fixed here).

        Returns:
            Tensor of predicted token indices.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        # nn.Module's metaclass is plain ``type`` rather than ``ABCMeta``, so
        # ``@abstractmethod`` does not block instantiation of a subclass that
        # lacks ``predict``. Fail loudly instead of silently returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict()"
        )