summaryrefslogtreecommitdiff
path: root/text_recognizer/network/convformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/network/convformer.py')
-rw-r--r--text_recognizer/network/convformer.py16
1 files changed, 6 insertions, 10 deletions
diff --git a/text_recognizer/network/convformer.py b/text_recognizer/network/convformer.py
index 0ee5487..e2b0204 100644
--- a/text_recognizer/network/convformer.py
+++ b/text_recognizer/network/convformer.py
@@ -1,12 +1,10 @@
-from typing import Optional
from einops.layers.torch import Rearrange
from torch import Tensor, nn
-from text_recognizer.network.convnext.convnext import ConvNext
-from .transformer.embedding.token import TokenEmbedding
-from .transformer.embedding.sincos import sincos_2d
from .transformer.decoder import Decoder
+from .transformer.embedding.sincos import sincos_2d
+from .transformer.embedding.token import TokenEmbedding
from .transformer.encoder import Encoder
@@ -24,12 +22,10 @@ class Convformer(nn.Module):
token_embedding: TokenEmbedding,
tie_embeddings: bool,
pad_index: int,
- stem: Optional[ConvNext] = None,
channels: int = 1,
) -> None:
super().__init__()
patch_dim = patch_height * patch_width * channels
- self.stem = stem if stem is not None else nn.Identity()
self.to_patch_embedding = nn.Sequential(
Rearrange(
"b c (h ph) (w pw) -> b (h w) (ph pw c)",
@@ -53,11 +49,11 @@ class Convformer(nn.Module):
self.decoder = decoder
self.pad_index = pad_index
- def encode(self, img: Tensor) -> Tensor:
- x = self.stem(img)
+ def encode(self, images: Tensor) -> Tensor:
+ x = self.encoder(images)
x = self.to_patch_embedding(x)
- x += self.patch_embedding.to(img.device, dtype=img.dtype)
- return self.encoder(x)
+ x = x + self.patch_embedding.to(images.device, dtype=images.dtype)
+ return x
def decode(self, text: Tensor, img_features: Tensor) -> Tensor:
text = text.long()