From 73edf3fcb4c47cda27e230c98718d4abdc3400e2 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 15 Apr 2024 21:50:14 +0200
Subject: Update mammut

---
 text_recognizer/model/mammut.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/model/mammut.py b/text_recognizer/model/mammut.py
index 5d33492..320f228 100644
--- a/text_recognizer/model/mammut.py
+++ b/text_recognizer/model/mammut.py
@@ -1,16 +1,17 @@
 """Lightning model for transformer networks."""
 from typing import Callable, Optional, Tuple, Type
-from text_recognizer.network.mammut import MaMMUT
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from omegaconf import DictConfig
-from torch import einsum, nn, Tensor
+from torch import Tensor, einsum, nn
 from torchmetrics import CharErrorRate, WordErrorRate
-import torch.nn.functional as F
 
-from text_recognizer.decoder.greedy_decoder import GreedyDecoder
 from text_recognizer.data.tokenizer import Tokenizer
+from text_recognizer.decoder.greedy_decoder import GreedyDecoder
+from text_recognizer.network.mammut import MaMMUT
+
 from .base import LitBase
 
 
@@ -49,8 +50,8 @@ class LitMaMMUT(LitBase):
         """Autoregressive forward pass."""
         return self.predict(data)
 
-    def to_caption_loss(self, logits: Tensor, text: Tensor) -> Tensor:
-        caption_loss = self.loss_fn(logits, text[:, 1:])
+    def to_caption_loss(self, logits: Tensor, labels: Tensor) -> Tensor:
+        caption_loss = self.loss_fn(logits, labels)
         return self.caption_loss_weight * caption_loss
 
     def to_contrastive_loss(
@@ -69,14 +70,15 @@ class LitMaMMUT(LitBase):
         ) / 2
         return self.contrastive_loss_weight * contrastive_loss
 
-    def teacher_forward(self, images: Tensor, text: Tensor) -> Tuple[Tensor, Tensor]:
+    def teacher_forward(self, images: Tensor, tagets: Tensor) -> Tuple[Tensor, Tensor]:
         """Non-autoregressive forward pass."""
-        text_embeddings = self.network.to_text_cls_features(text[:, :-1])
+        text, labels = tagets[:, :-1], tagets[:, 1:]
+        text_embeddings = self.network.to_text_cls_features(text)
         image_embeddings, image_features = self.network.to_image_features(images)
-        logits = self.network.decode(text[:, :-1], image_features)
+        logits = self.network.decode(text, image_features)
         logits = rearrange(logits, "b n c -> b c n")
 
-        caption_loss = self.to_caption_loss(logits, text)
+        caption_loss = self.to_caption_loss(logits, labels)
         contrastive_loss = self.to_contrastive_loss(image_embeddings, text_embeddings)
 
         self.log("train/caption_loss", caption_loss)
-- 
cgit v1.2.3-70-g09d2