diff options
Diffstat (limited to 'text_recognizer/networks/barlow_twins/projector.py')
-rw-r--r-- | text_recognizer/networks/barlow_twins/projector.py | 36 |
1 files changed, 0 insertions, 36 deletions
diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py deleted file mode 100644 index 05d5e2e..0000000 --- a/text_recognizer/networks/barlow_twins/projector.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Projector network in Barlow Twins.""" - -from typing import List -import torch -from torch import nn -from torch import Tensor - - -class Projector(nn.Module): - """MLP network.""" - - def __init__(self, dims: List[int]) -> None: - super().__init__() - self.dims = dims - self.network = self._build() - - def _build(self) -> nn.Sequential: - """Builds projector network.""" - layers = [ - nn.Sequential( - nn.Linear( - in_features=self.dims[i], out_features=self.dims[i + 1], bias=False - ), - nn.BatchNorm1d(self.dims[i + 1]), - nn.ReLU(inplace=True), - ) - for i in range(len(self.dims) - 2) - ] - layers.append( - nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False) - ) - return nn.Sequential(*layers) - - def forward(self, x: Tensor) -> Tensor: - """Project latent to higher dimesion.""" - return self.network(x) |