mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fused attention for single query (#1497)
This commit is contained in:
committed by
GitHub
parent
9dd72cd421
commit
50d8bed468
@@ -1,20 +1,13 @@
|
||||
//
|
||||
// scaled_dot_product_attention.cpp
|
||||
// mlx
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
@@ -26,8 +19,7 @@ void sdpa_full_self_attention_metal(
|
||||
const array& k,
|
||||
const array& v,
|
||||
const float alpha,
|
||||
array& out,
|
||||
std::vector<array>& temporaries) {
|
||||
array& out) {
|
||||
std::ostringstream kname_self_attention;
|
||||
kname_self_attention << "steel_gemm_attention_";
|
||||
|
||||
@@ -148,130 +140,58 @@ void sdpa_full_self_attention_metal(
|
||||
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
void sdpa_metal(
|
||||
void sdpa_vector(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
const array& p_lse,
|
||||
const array& p_rowmaxes,
|
||||
const array& o_partial,
|
||||
const uint heads,
|
||||
const uint tile_size,
|
||||
const uint n_tiles,
|
||||
const float alpha,
|
||||
array& out,
|
||||
std::vector<array>& temporaries) {
|
||||
std::ostringstream kname_partials;
|
||||
float scale) {
|
||||
// Set the kernel name
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
kname += "sdpa_vector_";
|
||||
kname += get_type_string(q.dtype());
|
||||
kname += "_";
|
||||
kname += std::to_string(q.shape(-1));
|
||||
|
||||
kname_partials << "fast_inference_sdpa_compute_partials_";
|
||||
// Compute the necessary sizes
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
int N = k.shape(2);
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
size_t stride = k.strides()[1];
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
MTL::Size grid_dims(1, B, 1);
|
||||
|
||||
std::ostringstream kname_reduce;
|
||||
std::string delimiter = "_";
|
||||
kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
|
||||
|
||||
for (const auto& arr : {k, v, out}) {
|
||||
if (arr.dtype() != q.dtype()) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
|
||||
}
|
||||
}
|
||||
|
||||
if (q.dtype() == float32) {
|
||||
kname_partials << "float" + delimiter;
|
||||
kname_reduce << "float";
|
||||
} else if (q.dtype() == float16) {
|
||||
kname_partials << "half" + delimiter;
|
||||
kname_reduce << "half";
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
|
||||
}
|
||||
|
||||
std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
|
||||
|
||||
uint nsimd = 8;
|
||||
std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
|
||||
|
||||
// maximum number of splits == 128 at the moment (reserved tile registers in
|
||||
// reduction kernel). this is arbitrary and could be changed in the shader.
|
||||
|
||||
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
|
||||
kname_partials << kname_suffix;
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname_partials.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
constexpr const uint batch = 1;
|
||||
MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
|
||||
MTL::Size group_dims = MTL::Size(32, nsimd, 1);
|
||||
|
||||
const uint64_t KV_sequence_length = k.shape(-2);
|
||||
const uint query_sequence_length = q.shape(-2);
|
||||
const uint n_q_heads = q.shape(1);
|
||||
const uint n_kv_heads = k.shape(1);
|
||||
|
||||
MLXScaledDotProductAttentionParams params{
|
||||
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
|
||||
|
||||
compute_encoder.set_input_array(q, 0);
|
||||
// Set its arguments
|
||||
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
|
||||
compute_encoder->setBytes(
|
||||
¶ms, sizeof(MLXScaledDotProductAttentionParams), 4);
|
||||
compute_encoder.set_input_array(o_partial, 5);
|
||||
compute_encoder.set_input_array(p_lse, 6);
|
||||
compute_encoder.set_input_array(p_rowmaxes, 7);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&stride, sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 7);
|
||||
|
||||
constexpr const uint tgroupMemorySize = 32768;
|
||||
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||
// Launch
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
{
|
||||
auto kernel_accum = d.get_kernel(kname_reduce.str());
|
||||
compute_encoder->setComputePipelineState(kernel_accum);
|
||||
compute_encoder.set_input_array(o_partial, 0);
|
||||
compute_encoder.set_input_array(p_lse, 1);
|
||||
compute_encoder.set_input_array(p_rowmaxes, 2);
|
||||
compute_encoder->setBytes(
|
||||
¶ms, sizeof(MLXScaledDotProductAttentionParams), 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
|
||||
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
|
||||
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() >= 3);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
||||
}
|
||||
assert(inputs.size() == 3);
|
||||
|
||||
if (inputs.size() == 4) {
|
||||
out = fallback_(inputs)[0];
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
@@ -279,84 +199,75 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
auto& k_pre = inputs[1];
|
||||
auto& v_pre = inputs[2];
|
||||
auto& o = out;
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> temporaries;
|
||||
auto check_transpose = [&temporaries, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
std::vector<array> copies;
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
|
||||
if (!predicate(arr)) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
temporaries.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
} else {
|
||||
return arr;
|
||||
}
|
||||
};
|
||||
|
||||
auto q = check_transpose(q_pre);
|
||||
auto k = check_transpose(k_pre);
|
||||
auto v = check_transpose(v_pre);
|
||||
// Checks if arr is fully row contiguous
|
||||
auto is_contiguous = [](const array& arr) {
|
||||
return arr.flags().row_contiguous;
|
||||
};
|
||||
|
||||
const int heads = q.shape(-3);
|
||||
// Returns true if the array is row contiguous except the sequence length
|
||||
// dimension that can be sliced but with step=1.
|
||||
auto is_contiguous_except_seq_len = [](const array& arr) {
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return strides[3] == 1 && strides[2] == shape[3] &&
|
||||
strides[0] == strides[1] * shape[1];
|
||||
};
|
||||
|
||||
uint query_sequence_length = q.shape(-2);
|
||||
if (query_sequence_length >= 16) {
|
||||
return sdpa_full_self_attention_metal(
|
||||
s, d, q, k, v, scale_, out, temporaries);
|
||||
}
|
||||
int tile_size = 64;
|
||||
const int kv_seq_len = k.shape(-2);
|
||||
if (kv_seq_len > 8000) {
|
||||
tile_size = 128;
|
||||
}
|
||||
if (kv_seq_len > 16000) {
|
||||
tile_size = 256;
|
||||
}
|
||||
if (kv_seq_len > 32000) {
|
||||
tile_size = 512;
|
||||
// Checks that the last two dims are row contiguous.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
auto& strides = arr.strides();
|
||||
auto& shape = arr.shape();
|
||||
return strides[3] == 1 && strides[2] == shape[3];
|
||||
};
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) == 1) {
|
||||
auto q = copy_unless(is_contiguous, q_pre);
|
||||
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable()) {
|
||||
o.move_shared_buffer(q);
|
||||
} else {
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
}
|
||||
|
||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||
}
|
||||
|
||||
const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
|
||||
// Full attention mode
|
||||
else {
|
||||
auto q = copy_unless(is_matrix_contiguous, q_pre);
|
||||
auto k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
auto v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
|
||||
array o_partials(
|
||||
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
|
||||
float32,
|
||||
nullptr,
|
||||
{});
|
||||
o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
|
||||
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
|
||||
}
|
||||
|
||||
array p_lse(
|
||||
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
|
||||
array p_rowmaxes(
|
||||
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
|
||||
p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
|
||||
p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
|
||||
|
||||
temporaries.push_back(p_lse);
|
||||
temporaries.push_back(p_rowmaxes);
|
||||
temporaries.push_back(o_partials);
|
||||
|
||||
return sdpa_metal(
|
||||
s,
|
||||
d,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
p_lse,
|
||||
p_rowmaxes,
|
||||
o_partials,
|
||||
heads,
|
||||
tile_size,
|
||||
n_tiles,
|
||||
scale_,
|
||||
out,
|
||||
temporaries);
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
||||
Reference in New Issue
Block a user