summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-15 21:50:14 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-15 21:50:14 +0200
commit73edf3fcb4c47cda27e230c98718d4abdc3400e2 (patch)
treedb5f8e9d2a9a9c331d0eba01e3cf269345684611
parent2c5093dda783cf7618a8554ae649d32b92b84b4c (diff)
Update mammut
-rw-r--r--text_recognizer/model/mammut.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/text_recognizer/model/mammut.py b/text_recognizer/model/mammut.py
index 5d33492..320f228 100644
--- a/text_recognizer/model/mammut.py
+++ b/text_recognizer/model/mammut.py
@@ -1,16 +1,17 @@
"""Lightning model for transformer networks."""
from typing import Callable, Optional, Tuple, Type
-from text_recognizer.network.mammut import MaMMUT
import torch
+import torch.nn.functional as F
from einops import rearrange
from omegaconf import DictConfig
-from torch import einsum, nn, Tensor
+from torch import Tensor, einsum, nn
from torchmetrics import CharErrorRate, WordErrorRate
-import torch.nn.functional as F
-from text_recognizer.decoder.greedy_decoder import GreedyDecoder
from text_recognizer.data.tokenizer import Tokenizer
+from text_recognizer.decoder.greedy_decoder import GreedyDecoder
+from text_recognizer.network.mammut import MaMMUT
+
from .base import LitBase
@@ -49,8 +50,8 @@ class LitMaMMUT(LitBase):
"""Autoregressive forward pass."""
return self.predict(data)
- def to_caption_loss(self, logits: Tensor, text: Tensor) -> Tensor:
- caption_loss = self.loss_fn(logits, text[:, 1:])
+ def to_caption_loss(self, logits: Tensor, labels: Tensor) -> Tensor:
+ caption_loss = self.loss_fn(logits, labels)
return self.caption_loss_weight * caption_loss
def to_contrastive_loss(
@@ -69,14 +70,15 @@ class LitMaMMUT(LitBase):
) / 2
return self.contrastive_loss_weight * contrastive_loss
- def teacher_forward(self, images: Tensor, text: Tensor) -> Tuple[Tensor, Tensor]:
+ def teacher_forward(self, images: Tensor, tagets: Tensor) -> Tuple[Tensor, Tensor]:
"""Non-autoregressive forward pass."""
- text_embeddings = self.network.to_text_cls_features(text[:, :-1])
+ text, labels = tagets[:, :-1], tagets[:, 1:]
+ text_embeddings = self.network.to_text_cls_features(text)
image_embeddings, image_features = self.network.to_image_features(images)
- logits = self.network.decode(text[:, :-1], image_features)
+ logits = self.network.decode(text, image_features)
logits = rearrange(logits, "b n c -> b c n")
- caption_loss = self.to_caption_loss(logits, text)
+ caption_loss = self.to_caption_loss(logits, labels)
contrastive_loss = self.to_contrastive_loss(image_embeddings, text_embeddings)
self.log("train/caption_loss", caption_loss)