Add stricter condition to matrix sdpa
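
The matrix (full-attention) SDPA path on CUDA now additionally requires the query and key sequence lengths to match: use_fallback gains a query_sequence_length == key_sequence_length check in supported_matrix_config. The vector fallback kernels also switch their template arguments from invoking the dispatch tags (do_causal(), headdim()) to reading their .value members.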

Angelos Katharopoulos
2025-08-06 19:51:14 -07:00
parent 99d8de8445
commit a22d0bf273


@@ -499,7 +499,8 @@ void sdpa_vector_1pass_fallback(
     dispatch_headdim(params.D, [&](auto headdim) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-      auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>;
+      auto kernel =
+          cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
       encoder.add_kernel_node(
           kernel,
           grid_dim,
@@ -569,8 +570,8 @@ void sdpa_vector_2pass_fallback(
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       {
-        auto kernel =
-            cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>;
+        auto kernel = cu::
+            kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
         encoder.set_input_array(q);
         encoder.set_input_array(k);
@@ -597,8 +598,8 @@ void sdpa_vector_2pass_fallback(
       }
       {
-        auto kernel =
-            cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>;
+        auto kernel = cu::
+            kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
         encoder.set_input_array(intermediate);
         encoder.set_input_array(sums);
@@ -963,6 +964,7 @@ bool ScaledDotProductAttention::use_fallback(
   const bool supported_matrix_config = query_sequence_length > 4 &&
       cu_device.compute_capability_major() >= 8 &&
+      query_sequence_length == key_sequence_length &&
       (q.dtype() == float16 || q.dtype() == bfloat16);
   const bool supported_config =
@@ -1105,4 +1107,4 @@ void ScaledDotProductAttention::eval_gpu(
 } // namespace fast
-} // namespace mlx::core
+} // namespace mlx::core
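
For context, here is a minimal sketch of the tag-dispatch pattern the .value change relies on. This is an illustration, not MLX's actual helpers: dispatch_bool and fake_kernel below are hypothetical stand-ins. The idea is that a runtime flag is mapped to a std::integral_constant tag, whose static value member is a constant expression and can therefore be used as a template argument inside the dispatch lambda.

    // A minimal sketch (not MLX's actual helpers) of the tag-dispatch
    // pattern used above; dispatch_bool and fake_kernel are hypothetical.
    #include <cstdio>
    #include <type_traits>

    // Map a runtime bool onto a std::integral_constant tag and pass it to f.
    template <typename F>
    void dispatch_bool(bool b, F&& f) {
      if (b) {
        f(std::integral_constant<bool, true>{});
      } else {
        f(std::integral_constant<bool, false>{});
      }
    }

    // A kernel template parameterized on the compile-time flag.
    template <bool DoCausal>
    void fake_kernel() {
      std::printf("do_causal = %d\n", DoCausal);
    }

    int main() {
      bool causal = true; // runtime value
      dispatch_bool(causal, [&](auto do_causal) {
        // do_causal is an integral_constant; its static `value` member is a
        // constant expression, so it can appear as a template argument,
        // just like do_causal.value and headdim.value in the diff above.
        auto kernel = &fake_kernel<do_causal.value>;
        kernel();
      });
      return 0;
    }

Note that for a true std::integral_constant both spellings work, since its operator() also returns value; the switch to .value presumably makes the constant-expression intent explicit and stays valid even if the tag type only exposes a static value member.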