mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
Add stricter condition to matrix sdpa
This commit is contained in:
@@ -499,7 +499,8 @@ void sdpa_vector_1pass_fallback(
|
|||||||
dispatch_headdim(params.D, [&](auto headdim) {
|
dispatch_headdim(params.D, [&](auto headdim) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>;
|
auto kernel =
|
||||||
|
cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
grid_dim,
|
grid_dim,
|
||||||
@@ -569,8 +570,8 @@ void sdpa_vector_2pass_fallback(
|
|||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
{
|
{
|
||||||
auto kernel =
|
auto kernel = cu::
|
||||||
cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>;
|
kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
|
||||||
|
|
||||||
encoder.set_input_array(q);
|
encoder.set_input_array(q);
|
||||||
encoder.set_input_array(k);
|
encoder.set_input_array(k);
|
||||||
@@ -597,8 +598,8 @@ void sdpa_vector_2pass_fallback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
auto kernel =
|
auto kernel = cu::
|
||||||
cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>;
|
kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
|
||||||
|
|
||||||
encoder.set_input_array(intermediate);
|
encoder.set_input_array(intermediate);
|
||||||
encoder.set_input_array(sums);
|
encoder.set_input_array(sums);
|
||||||
@@ -963,6 +964,7 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
|
|
||||||
const bool supported_matrix_config = query_sequence_length > 4 &&
|
const bool supported_matrix_config = query_sequence_length > 4 &&
|
||||||
cu_device.compute_capability_major() >= 8 &&
|
cu_device.compute_capability_major() >= 8 &&
|
||||||
|
query_sequence_length == key_sequence_length &&
|
||||||
(q.dtype() == float16 || q.dtype() == bfloat16);
|
(q.dtype() == float16 || q.dtype() == bfloat16);
|
||||||
|
|
||||||
const bool supported_config =
|
const bool supported_config =
|
||||||
@@ -1105,4 +1107,4 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
Reference in New Issue
Block a user