diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 58800fb70b..2c93d68618 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -586,6 +586,13 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } + if (mask && (*mask).dtype() != final_type) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Mask should match output type. " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + auto q = astype(queries, final_type, s); auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s);