mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add stricter condition to matrix sdpa
This commit is contained in:
@@ -499,7 +499,8 @@ void sdpa_vector_1pass_fallback(
|
||||
dispatch_headdim(params.D, [&](auto headdim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
-auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>;
+auto kernel =
+cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid_dim,
|
||||
@@ -569,8 +570,8 @@ void sdpa_vector_2pass_fallback(
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
{
|
||||
-auto kernel =
-cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>;
+auto kernel = cu::
+kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
|
||||
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
@@ -597,8 +598,8 @@ void sdpa_vector_2pass_fallback(
|
||||
}
|
||||
|
||||
{
|
||||
-auto kernel =
-cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>;
+auto kernel = cu::
+kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
|
||||
|
||||
encoder.set_input_array(intermediate);
|
||||
encoder.set_input_array(sums);
|
||||
@@ -963,6 +964,7 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
|
||||
const bool supported_matrix_config = query_sequence_length > 4 &&
|
||||
cu_device.compute_capability_major() >= 8 &&
|
||||
query_sequence_length == key_sequence_length &&
|
||||
(q.dtype() == float16 || q.dtype() == bfloat16);
|
||||
|
||||
const bool supported_config =
|
||||
@@ -1105,4 +1107,4 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
} // namespace fast
|
||||
|
||||
-} // namespace mlx::core
+} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user