diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:27:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:27:24 +0200 |
commit | 9a8044f4a3826a119416665741b709cd686fca87 (patch) | |
tree | e339593bb4e3858fa9379d14752dc52bf5949825 /text_recognizer/networks | |
parent | ae8bfa62f0e02bd70c27bc1e71697249a5a79e7e (diff) |
Remove Barlow Twins
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/barlow_twins/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/barlow_twins/network.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/barlow_twins/projector.py | 36 |
3 files changed, 0 insertions, 55 deletions
diff --git a/text_recognizer/networks/barlow_twins/__init__.py b/text_recognizer/networks/barlow_twins/__init__.py deleted file mode 100644 index 0b74818..0000000 --- a/text_recognizer/networks/barlow_twins/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for projector network in Barlow Twins.""" diff --git a/text_recognizer/networks/barlow_twins/network.py b/text_recognizer/networks/barlow_twins/network.py deleted file mode 100644 index a3e3750..0000000 --- a/text_recognizer/networks/barlow_twins/network.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Barlow Twins network.""" -from typing import Type - -from torch import nn, Tensor -import torch.nn.functional as F - - -class BarlowTwins(nn.Module): - def __init__(self, encoder: Type[nn.Module], projector: Type[nn.Module]) -> None: - super().__init__() - self.encoder = encoder - self.projector = projector - - def forward(self, x: Tensor) -> Tensor: - z = self.encoder(x) - z_e = F.adaptive_avg_pool2d(z, (1, 1)).flatten(start_dim=1) - z_p = self.projector(z_e) - return z_p diff --git a/text_recognizer/networks/barlow_twins/projector.py b/text_recognizer/networks/barlow_twins/projector.py deleted file mode 100644 index 05d5e2e..0000000 --- a/text_recognizer/networks/barlow_twins/projector.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Projector network in Barlow Twins.""" - -from typing import List -import torch -from torch import nn -from torch import Tensor - - -class Projector(nn.Module): - """MLP network.""" - - def __init__(self, dims: List[int]) -> None: - super().__init__() - self.dims = dims - self.network = self._build() - - def _build(self) -> nn.Sequential: - """Builds projector network.""" - layers = [ - nn.Sequential( - nn.Linear( - in_features=self.dims[i], out_features=self.dims[i + 1], bias=False - ), - nn.BatchNorm1d(self.dims[i + 1]), - nn.ReLU(inplace=True), - ) - for i in range(len(self.dims) - 2) - ] - layers.append( - nn.Linear(in_features=self.dims[-2], out_features=self.dims[-1], bias=False) - ) - return nn.Sequential(*layers) - - def forward(self, x: Tensor) -> Tensor: - """Project latent to higher dimesion.""" - return self.network(x) |