summaryrefslogtreecommitdiff
path: root/models/losses.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/losses.py')
-rw-r--r--models/losses.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/models/losses.py b/models/losses.py
new file mode 100644
index 0000000..b3118e7
--- /dev/null
+++ b/models/losses.py
@@ -0,0 +1,101 @@
+from typing import Any, Tuple, Dict, Sequence, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+IGNORE_LABEL_ID = -100
+
+
+def s(x, epsilon=1e-30):
+ return torch.where(
+ x<0,
+ 1/(1-x+ epsilon),
+ x + 1
+ )
+
+
+def log_stablemax(x, dim=-1):
+ s_x = s(x)
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
+
+
+def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+ valid_mask = labels != ignore_index
+ transformed_labels = torch.where(valid_mask, labels, 0)
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+ return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+ # Cast logits to f32
+ # Flatten logits
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+class ACTLossHead(nn.Module):
+ def __init__(self, model: nn.Module, loss_type: str):
+ super().__init__()
+ self.model = model
+ self.loss_fn = globals()[loss_type]
+
+ def initial_carry(self, *args, **kwargs):
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
+
+ def forward(
+ self,
+ return_keys: Sequence[str],
+ # Model args
+ **model_kwargs,
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+ # Model logits
+ # B x SeqLen x D
+ new_carry, outputs = self.model(**model_kwargs)
+ labels = new_carry.current_data["labels"]
+
+ # Correctness
+ with torch.no_grad():
+ mask = labels != IGNORE_LABEL_ID
+ loss_counts = mask.sum(-1)
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
+
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
+ seq_is_correct = is_correct.sum(-1) == loss_counts
+
+ # Metrics (halted)
+ valid_metrics = new_carry.halted & (loss_counts > 0)
+ metrics = {
+ "count": valid_metrics.sum(),
+
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+ }
+
+ # Losses
+ # FIXME: Assuming the batch is always full
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+
+ metrics.update({
+ "lm_loss": lm_loss.detach(),
+ "q_halt_loss": q_halt_loss.detach(),
+ })
+
+ # Q continue (bootstrapping target loss)
+ q_continue_loss = 0
+ if "target_q_continue" in outputs:
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+ metrics["q_continue_loss"] = q_continue_loss.detach()
+
+ # Filter outputs for return
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()