diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 2c93d6861..c7dfb36e6 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -586,9 +586,9 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (mask && (*mask).dtype() != final_type) { + if (mask && promote_types((*mask).dtype(), final_type) != final_type) { std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask should match output type. " + msg << "[scaled_dot_product_attention] Mask type must promote to output type. " << final_type << "."; throw std::invalid_argument(msg.str()); }