Diffstat (limited to 'text_recognizer')
-rw-r--r--	text_recognizer/networks/transformer/norm.py	|	16 ++++++++++++--
1 file changed, 14 insertions(+), 2 deletions(-)
diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py
index 537246d..4cd3b5b 100644
--- a/text_recognizer/networks/transformer/norm.py
+++ b/text_recognizer/networks/transformer/norm.py
@@ -4,7 +4,7 @@ Copied from lucidrains:
 https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
 
 """
-from typing import Dict, Type
+from typing import Dict, Optional, Type
 
 import torch
 from torch import nn
@@ -29,12 +29,24 @@ class RMSNorm(nn.Module):
 class PreNorm(nn.Module):
     """Applies layer normalization then function."""
 
-    def __init__(self, normalized_shape: int, fn: Type[nn.Module]) -> None:
+    def __init__(
+        self,
+        normalized_shape: int,
+        fn: Type[nn.Module],
+        context_dim: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.norm = nn.LayerNorm(normalized_shape)
         self.fn = fn
+        self.norm_context = (
+            nn.LayerNorm(context_dim) if context_dim is not None else None
+        )
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
         """Applies pre norm."""
         x = self.norm(x)
+        if self.norm_context is not None:
+            context = kwargs["context"]
+            normed_context = self.norm_context(context)
+            kwargs.update(context=normed_context)
         return self.fn(x, **kwargs)
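
For reference, a minimal usage sketch of the new context normalization. This is not part of the commit: ContextAdd and the tensor shapes are hypothetical stand-ins, and the import assumes the text_recognizer package is importable.

# Hypothetical usage sketch: wrap a module that consumes a `context` kwarg,
# passing context_dim so the context tensor is layer-normalized as well.
import torch
from torch import nn, Tensor

from text_recognizer.networks.transformer.norm import PreNorm


class ContextAdd(nn.Module):
    """Toy stand-in for a cross-attention block that reads `context`."""

    def forward(self, x: Tensor, context: Tensor) -> Tensor:
        # Broadcast-add the mean context vector onto every position of x.
        return x + context.mean(dim=1, keepdim=True)


block = PreNorm(normalized_shape=64, fn=ContextAdd(), context_dim=64)
x = torch.randn(2, 16, 64)       # (batch, seq, dim)
memory = torch.randn(2, 8, 64)   # e.g. encoder output
out = block(x, context=memory)   # both x and context are normed before fn
print(out.shape)                 # torch.Size([2, 16, 64])

Leaving context_dim at its default of None skips the extra LayerNorm entirely, so existing call sites that wrap context-free modules are unaffected by this change.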