From fd9b1570c568d9ce8f1ac7258f05f9977a5cc9c8 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 13 Sep 2022 18:45:58 +0200
Subject: Remove perceiver net

---
 text_recognizer/networks/perceiver/__init__.py  |  1 -
 text_recognizer/networks/perceiver/attention.py | 40 -----------
 text_recognizer/networks/perceiver/perceiver.py | 91 -------------------------
 3 files changed, 132 deletions(-)
 delete mode 100644 text_recognizer/networks/perceiver/__init__.py
 delete mode 100644 text_recognizer/networks/perceiver/attention.py
 delete mode 100644 text_recognizer/networks/perceiver/perceiver.py

(limited to 'text_recognizer')

diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
deleted file mode 100644
index ac2c102..0000000
--- a/text_recognizer/networks/perceiver/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from text_recognizer.networks.perceiver.perceiver import PerceiverIO
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
deleted file mode 100644
index 19e3e17..0000000
--- a/text_recognizer/networks/perceiver/attention.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Attention module."""
-from typing import Optional
-
-from einops import rearrange, repeat
-import torch
-from torch import einsum, nn, Tensor
-import torch.nn.functional as F
-
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        query_dim: int,
-        context_dim: Optional[int] = None,
-        heads: int = 8,
-        dim_head: int = 64,
-    ) -> None:
-        super().__init__()
-        inner_dim = heads * dim_head
-        context_dim = context_dim if context_dim is not None else query_dim
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
-        self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
-
-    def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
-        h = self.heads
-        q = self.to_q(x)
-        context = context if context is not None else x
-        k, v = self.to_kv(context).chunk(2, dim=-1)
-
-        q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        attn = sim.softmax(dim=-1)
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-        return self.to_out(out)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
deleted file mode 100644
index 5b4ab26..0000000
--- a/text_recognizer/networks/perceiver/perceiver.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""Perceiver IO.
-
-A copy from lucidrains.
-"""
-from typing import Optional
-
-from einops import repeat, rearrange
-import torch
-from torch import nn, Tensor
-
-from text_recognizer.networks.perceiver.attention import Attention
-from text_recognizer.networks.transformer.ff import FeedForward
-from text_recognizer.networks.transformer.norm import PreNorm
-
-
-class PerceiverIO(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        cross_heads: int,
-        cross_head_dim: int,
-        num_latents: int,
-        latent_dim: int,
-        latent_heads: int,
-        depth: int,
-        queries_dim: int,
-        logits_dim: int,
-    ) -> None:
-        super().__init__()
-        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
-
-        self.cross_attn_block = nn.ModuleList(
-            [
-                PreNorm(
-                    latent_dim,
-                    Attention(
-                        latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
-                    ),
-                    context_dim=dim,
-                ),
-                PreNorm(latent_dim, FeedForward(latent_dim)),
-            ]
-        )
-
-        self.layers = nn.ModuleList(
-            [
-                nn.ModuleList(
-                    [
-                        PreNorm(
-                            latent_dim,
-                            Attention(
-                                latent_dim, heads=latent_heads, dim_head=latent_dim
-                            ),
-                        ),
-                        PreNorm(latent_dim, FeedForward(latent_dim)),
-                    ]
-                )
-                for _ in range(depth)
-            ]
-        )
-
-        self.decoder_cross_attn = PreNorm(
-            queries_dim,
-            Attention(
-                queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
-            ),
-            context_dim=latent_dim,
-        )
-        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
-        self.to_logits = nn.Linear(queries_dim, logits_dim)
-
-    def forward(self, data: Tensor, queries: Tensor) -> Tensor:
-        b = data.shape[0]
-        x = repeat(self.latents, "n d -> b n d", b=b)
-
-        cross_attn, cross_ff = self.cross_attn_block
-
-        x = cross_attn(x, context=data) + x
-        x = cross_ff(x) + x
-
-        for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
-
-        if queries.ndim == 2:
-            queries = repeat(queries, "nd->bnd", b=b)
-
-        latents = self.decoder_cross_attn(queries, context=x)
-        latents = latents + self.decoder_ff(latents)
-
-        return self.to_logits(latents)
-- 
cgit v1.2.3-70-g09d2