diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 97a3a5f6a..dbc38cb1d 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -880,6 +880,11 @@ std::vector ScaledDotProductAttention::vjp( std::vector returned_vjps; for (int arg : argnums) { + if (arg >= 3) { + throw std::invalid_argument( + "[scaled_dot_product_attention] Does not support VJP with respect " + "to mask or attention sinks."); + } returned_vjps.push_back(std::move(vjps[arg])); } return returned_vjps;