blob: 07b6a328155897e7bda015ba40c63e284d1826aa (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
"""Base network with required methods."""
from abc import abstractmethod
import attr
from torch import nn, Tensor
@attr.s(eq=False)
class BaseNetwork(nn.Module):
    """Base class for networks that expose a ``predict`` method.

    The class is an attrs class layered on top of ``nn.Module``.

    ``eq=False`` is required: with the attrs default (``eq=True`` on a
    non-frozen class) attrs generates ``__eq__`` and sets ``__hash__`` to
    ``None``, which makes instances unhashable. ``nn.Module`` instances must
    remain hashable because torch tracks visited modules in sets (for example
    in ``named_modules()``, which ``parameters()`` relies on).

    NOTE(review): ``abstractmethod`` is only enforced by ``abc.ABCMeta``; this
    class does not declare that metaclass, so a subclass that forgets to
    override ``predict`` can presumably still be instantiated -- confirm
    whether enforcement is wanted.
    """

    def __attrs_pre_init__(self) -> None:
        """Run ``nn.Module.__init__`` before the attrs-generated ``__init__``."""
        # attrs calls this hook at the very start of its generated __init__,
        # so the nn.Module machinery exists before any field is assigned.
        super().__init__()

    @abstractmethod
    def predict(self, x: Tensor) -> Tensor:
        """Return token indices for predictions."""
        ...
|