summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 01:45:34 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 01:45:34 +0200
commitffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d (patch)
treedb8c78232e588b12d7a8b408682783e0b5858615 /text_recognizer/networks/conv_transformer.py
parentcf2a827db5798a245dd5207685251675d311dbec (diff)
Add comments
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--text_recognizer/networks/conv_transformer.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index e36a786..d36162a 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,7 +1,6 @@
"""Base network module."""
from typing import Type
-import torch
from torch import Tensor, nn
from text_recognizer.networks.transformer.decoder import Decoder
@@ -28,11 +27,11 @@ class ConvTransformer(nn.Module):
return self.decoder(tokens, img_features)
def forward(self, img: Tensor, tokens: Tensor) -> Tensor:
- """Encodes images into word piece logtis.
+ """Encodes images into token logtis.
Args:
img (Tensor): Input image(s).
- tokens (Tensor): Target word embeddings.
+ tokens (Tensor): token embeddings.
Shapes:
- img: :math: `(B, 1, H, W)`