Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
2-Pass Sdpa Inference Kernel (#1597)
commit 073076ac7d · parent 9bd03dd9b4 · committed by GitHub
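The new kernel splits the softmax reduction over the key sequence into independent partitions: pass 1 writes an unnormalized partial output plus a per-partition running max and sum of exponentials (the `intermediate`, `maxs`, and `sums` buffers in the diff below), and pass 2 rescales every partition to a common max and normalizes once. Below is a minimal CPU sketch of that same two-pass reduction in plain C++; the toy sizes and variable names are illustrative assumptions, not the Metal kernels' internals.

// two_pass_sdpa_sketch.cpp -- illustrative only; sizes and names are
// assumptions, not the Metal kernels' internals.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int N = 64;     // keys in the toy example
  const int D = 4;      // head dimension
  const int blocks = 8; // sequence partitions (the kernel uses 32)
  const float scale = 1.0f / std::sqrt(float(D));

  // Toy single-query attention inputs
  std::vector<float> q(D), k(N * D), v(N * D);
  for (int d = 0; d < D; d++) q[d] = 0.1f * (d + 1);
  for (int i = 0; i < N * D; i++) {
    k[i] = std::sin(0.3f * i);
    v[i] = std::cos(0.2f * i);
  }

  // Pass 1: per partition, a local max, a local sum of exponentials, and an
  // unnormalized partial output (the roles of maxs, sums, and intermediate).
  const int per = N / blocks;
  std::vector<float> maxs(blocks), sums(blocks), inter(blocks * D, 0.0f);
  for (int b = 0; b < blocks; b++) {
    std::vector<float> scores(per);
    float m = -INFINITY;
    for (int j = 0; j < per; j++) {
      float s = 0.0f;
      for (int d = 0; d < D; d++) s += q[d] * k[(b * per + j) * D + d];
      scores[j] = s * scale;
      m = std::max(m, scores[j]);
    }
    float sum = 0.0f;
    for (int j = 0; j < per; j++) {
      float e = std::exp(scores[j] - m);
      sum += e;
      for (int d = 0; d < D; d++) inter[b * D + d] += e * v[(b * per + j) * D + d];
    }
    maxs[b] = m;
    sums[b] = sum;
  }

  // Pass 2: rescale all partitions to a common max, then normalize once.
  float M = -INFINITY;
  for (int b = 0; b < blocks; b++) M = std::max(M, maxs[b]);
  float total = 0.0f;
  std::vector<float> out(D, 0.0f);
  for (int b = 0; b < blocks; b++) {
    float f = std::exp(maxs[b] - M);
    total += sums[b] * f;
    for (int d = 0; d < D; d++) out[d] += inter[b * D + d] * f;
  }
  for (int d = 0; d < D; d++) printf("out[%d] = %f\n", d, out[d] / total);
  return 0;
}

Because exp(score - m_b) * exp(m_b - M) == exp(score - M), the combined result is bit-for-bit the same softmax attention a single pass would compute, while pass 1 parallelizes across the sequence.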
@@ -8,6 +8,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/utils.h"

 namespace mlx::core::fast {

@@ -184,6 +185,94 @@ void sdpa_vector(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

+void sdpa_vector_2pass(
+    const Stream& s,
+    metal::Device& d,
+    const array& q,
+    const array& k,
+    const array& v,
+    array& out,
+    float scale) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(64);
+  kname += "sdpa_vector_2pass_1_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int blocks = 32;
+  int B = q.shape(0) * q.shape(1);
+  size_t k_stride = k.strides()[1];
+  size_t v_stride = v.strides()[1];
+  MTL::Size group_dims(8 * 32, 1, 1);
+  MTL::Size grid_dims(1, B, blocks);
+
+  // Allocate the intermediates
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
+  intermediate_shape.push_back(blocks);
+  intermediate_shape.push_back(out.shape().back());
+  array intermediate(intermediate_shape, float32, nullptr, {});
+  intermediate_shape.pop_back();
+  array sums(intermediate_shape, float32, nullptr, {});
+  array maxs(std::move(intermediate_shape), float32, nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
+  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  d.add_temporary(intermediate, s.index);
+  d.add_temporary(sums, s.index);
+  d.add_temporary(maxs, s.index);
+
+  // Get the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(v, 2);
+  compute_encoder.set_output_array(intermediate, 3);
+  compute_encoder.set_output_array(sums, 4);
+  compute_encoder.set_output_array(maxs, 5);
+  compute_encoder.set_bytes(gqa_factor, 6);
+  compute_encoder.set_bytes(N, 7);
+  compute_encoder.set_bytes(k_stride, 8);
+  compute_encoder.set_bytes(v_stride, 9);
+  compute_encoder.set_bytes(scale, 10);
+
+  // Launch
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+
+  // Final pass
+  kname.clear();
+  kname += "sdpa_vector_2pass_2_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+
+  // Get the kernel
+  kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_input_array(sums, 1);
+  compute_encoder.set_input_array(maxs, 2);
+  compute_encoder.set_output_array(out, 3);
+
+  // Launch
+  group_dims = MTL::Size(1024, 1, 1);
+  grid_dims = MTL::Size(1, B, 1);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
 } // namespace

 void ScaledDotProductAttention::eval_gpu(
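Note how the host code reuses one shape vector for all three temporaries: `intermediate` gets `out.shape[:-1] + [blocks, head_dim]`, then a `pop_back` leaves `out.shape[:-1] + [blocks]` for `sums` and `maxs`. A small standalone check of that bookkeeping follows; the concrete sizes are assumptions for illustration.

// shape_bookkeeping_check.cpp -- mirrors the temporary-shape logic above.
#include <cstdio>
#include <vector>

static void print_shape(const char* name, const std::vector<int>& s) {
  printf("%s: {", name);
  for (size_t i = 0; i < s.size(); i++) printf("%s%d", i ? ", " : "", s[i]);
  printf("}\n");
}

int main() {
  // {batch, heads, q_len, head_dim} -- assumed example sizes
  std::vector<int> out_shape = {2, 32, 1, 128};
  int blocks = 32;

  std::vector<int> intermediate_shape;
  intermediate_shape.reserve(out_shape.size() + 1);
  intermediate_shape.insert(
      intermediate_shape.end(), out_shape.begin(), out_shape.end() - 1);
  intermediate_shape.push_back(blocks);
  intermediate_shape.push_back(out_shape.back());
  print_shape("intermediate", intermediate_shape); // {2, 32, 1, 32, 128}

  intermediate_shape.pop_back();
  print_shape("sums / maxs", intermediate_shape);  // {2, 32, 1, 32}
  return 0;
}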
@@ -249,7 +338,17 @@ void ScaledDotProductAttention::eval_gpu(
     } else {
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
     }
-    sdpa_vector(s, d, q, k, v, o, scale_);
+
+    // We route to the 2 pass fused attention if
+    // - The device is large and the sequence length long
+    // - The sequence length is even longer and we have gqa
+    char devc = d.get_architecture().back();
+    if ((devc == 'd' && k.shape(2) >= 1024) ||
+        (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
+      sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+    } else {
+      sdpa_vector(s, d, q, k, v, o, scale_);
+    }
   }

   // Full attention mode
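The routing condition reads the last character of the device architecture string ('d' marking a large device, per the comment in the diff) and only uses the two-pass kernel when the KV sequence is long enough, or even longer under grouped-query attention (fewer KV heads than query heads). Restated as a standalone predicate for illustration; the architecture strings in the example are made up.

// routing_predicate_sketch.cpp -- restates the dispatch heuristic above.
#include <cstdio>
#include <string>

static bool use_two_pass(
    const std::string& arch, // trailing 'd' marks a large device (per the diff)
    int q_heads,
    int kv_heads,
    int kv_seq_len) {
  char devc = arch.back();
  return (devc == 'd' && kv_seq_len >= 1024) ||
      (kv_heads < q_heads && kv_seq_len >= 4096);
}

int main() {
  // Large device, long sequence -> two-pass (arch string is made up)
  printf("%d\n", use_two_pass("applegpu_xd", 32, 32, 2048)); // 1
  // Small device, GQA, very long sequence -> two-pass
  printf("%d\n", use_two_pass("applegpu_xg", 32, 8, 8192));  // 1
  // Small device, shortish sequence -> single pass
  printf("%d\n", use_two_pass("applegpu_xg", 32, 32, 2048)); // 0
  return 0;
}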