From 1b591ec73673c68a3dc8f6a9d6e568b8ed07141e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 13 Dec 2025 19:48:39 -0800 Subject: [PATCH] No VJP for mask or sinks in attention (#2909) --- mlx/fast.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 97a3a5f6a..dbc38cb1d 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -880,6 +880,11 @@ std::vector ScaledDotProductAttention::vjp( std::vector returned_vjps; for (int arg : argnums) { + if (arg >= 3) { + throw std::invalid_argument( + "[scale_dot_product_attention] Does not support VJP with respect " + " to mask or attention sinks."); + } returned_vjps.push_back(std::move(vjps[arg])); } return returned_vjps;