mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-06 10:54:11 +08:00
[CUDA] Output of SDPA should have same layout with inputs (#2826)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
This commit is contained in:
@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
|
||||
return x;
|
||||
}
|
||||
|
||||
void malloc_with_same_layout(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& o,
|
||||
const array& q) {
|
||||
if (q.flags().row_contiguous) {
|
||||
o.set_data(cu::malloc_async(o.nbytes(), encoder));
|
||||
return;
|
||||
}
|
||||
// fill_order = argsort(q.strides())
|
||||
Shape fill_order(q.ndim());
|
||||
std::iota(fill_order.begin(), fill_order.end(), 0);
|
||||
std::stable_sort(
|
||||
fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
|
||||
auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
|
||||
auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
|
||||
return s1 < s2;
|
||||
});
|
||||
// Generate o_strides with fill_order
|
||||
Strides o_strides(q.ndim());
|
||||
int64_t stride = 1;
|
||||
for (int i : fill_order) {
|
||||
o_strides[i] = stride;
|
||||
stride *= o.shape(i);
|
||||
}
|
||||
// o is a transposed contiguous array
|
||||
o.set_data(
|
||||
cu::malloc_async(o.nbytes(), encoder),
|
||||
o.size(),
|
||||
o_strides,
|
||||
{true, false, false});
|
||||
}
|
||||
|
||||
constexpr int QKV_NDIM = 4;
|
||||
|
||||
struct SDPACacheKey {
|
||||
@@ -338,9 +370,7 @@ void sdpa_cudnn(
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
|
||||
// TODO: Handle donation.
|
||||
// TODO: Make O use same memory layout with Q.
|
||||
o.set_data(cu::malloc_async(o.nbytes(), encoder));
|
||||
malloc_with_same_layout(encoder, o, q);
|
||||
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
@@ -392,10 +422,9 @@ void sdpa_backward_cudnn(
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
|
||||
// TODO: Handle donation.
|
||||
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
|
||||
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
|
||||
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
|
||||
malloc_with_same_layout(encoder, d_q, q);
|
||||
malloc_with_same_layout(encoder, d_k, k);
|
||||
malloc_with_same_layout(encoder, d_v, v);
|
||||
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
|
||||
Reference in New Issue
Block a user