diff options
Diffstat (limited to 'text_recognizer/model')
-rw-r--r-- | text_recognizer/model/mammut.py | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/text_recognizer/model/mammut.py b/text_recognizer/model/mammut.py index 5d33492..320f228 100644 --- a/text_recognizer/model/mammut.py +++ b/text_recognizer/model/mammut.py @@ -1,16 +1,17 @@ """Lightning model for transformer networks.""" from typing import Callable, Optional, Tuple, Type -from text_recognizer.network.mammut import MaMMUT import torch +import torch.nn.functional as F from einops import rearrange from omegaconf import DictConfig -from torch import einsum, nn, Tensor +from torch import Tensor, einsum, nn from torchmetrics import CharErrorRate, WordErrorRate -import torch.nn.functional as F -from text_recognizer.decoder.greedy_decoder import GreedyDecoder from text_recognizer.data.tokenizer import Tokenizer +from text_recognizer.decoder.greedy_decoder import GreedyDecoder +from text_recognizer.network.mammut import MaMMUT + from .base import LitBase @@ -49,8 +50,8 @@ class LitMaMMUT(LitBase): """Autoregressive forward pass.""" return self.predict(data) - def to_caption_loss(self, logits: Tensor, text: Tensor) -> Tensor: - caption_loss = self.loss_fn(logits, text[:, 1:]) + def to_caption_loss(self, logits: Tensor, labels: Tensor) -> Tensor: + caption_loss = self.loss_fn(logits, labels) return self.caption_loss_weight * caption_loss def to_contrastive_loss( @@ -69,14 +70,15 @@ class LitMaMMUT(LitBase): ) / 2 return self.contrastive_loss_weight * contrastive_loss - def teacher_forward(self, images: Tensor, text: Tensor) -> Tuple[Tensor, Tensor]: + def teacher_forward(self, images: Tensor, tagets: Tensor) -> Tuple[Tensor, Tensor]: """Non-autoregressive forward pass.""" - text_embeddings = self.network.to_text_cls_features(text[:, :-1]) + text, labels = tagets[:, :-1], tagets[:, 1:] + text_embeddings = self.network.to_text_cls_features(text) image_embeddings, image_features = self.network.to_image_features(images) - logits = self.network.decode(text[:, :-1], image_features) + logits = self.network.decode(text, image_features) logits = rearrange(logits, "b n c -> b c n") - caption_loss = self.to_caption_loss(logits, text) + caption_loss = self.to_caption_loss(logits, labels) contrastive_loss = self.to_contrastive_loss(image_embeddings, text_embeddings) self.log("train/caption_loss", caption_loss) |