From 1f3ab1c091b44b119765f785eab16e7dd06dfa4d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 2 Jun 2022 23:38:19 +0200
Subject: Add mlp conformer layer

---
 text_recognizer/networks/conformer/mlp.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/conformer/mlp.py b/text_recognizer/networks/conformer/mlp.py
index e69de29..031bde9 100644
--- a/text_recognizer/networks/conformer/mlp.py
+++ b/text_recognizer/networks/conformer/mlp.py
@@ -0,0 +1,17 @@
+"""Conformer feedforward block."""
+from torch import nn, Tensor
+
+
+class MLP(nn.Module):
+    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0) -> None:
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(dim, mult * dim),
+            nn.Mish(inplace=True),
+            nn.Dropout(dropout),
+            nn.Linear(mult * dim, dim),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layers(x)
-- 
cgit v1.2.3-70-g09d2