Compare commits

..

3 Commits

Author SHA1 Message Date
Cheng
f8bd675655 [CUDA] Output of SDPA should have same layout with inputs (#2826)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-25 15:22:58 +09:00
Cheng
23a9168d34 [CUDA] Add debug env to save cuda graphs to dot files (#2825) 2025-11-25 15:22:36 +09:00
Awni Hannun
bca205e287 [CUDA] Exit on crash and more helpful errors (#2830) 2025-11-24 19:46:03 -08:00
6 changed files with 73 additions and 14 deletions

View File

@@ -154,17 +154,21 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
void* data = nullptr;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
err = cudaMallocManaged(&data, size);
} else {
err = cudaMallocAsync(&buf->data, size, stream);
err = cudaMallocAsync(&data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
if (!data) {
return Buffer{nullptr};
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
}

View File

@@ -29,6 +29,10 @@ class CudaHandle {
}
~CudaHandle() {
// Skip if there was an error to avoid throwing in the destructors.
// cudaPeekAtLastError reads the sticky error state without clearing it;
// once the context holds an error, releasing the handle would fail too,
// and reset() must not throw while a destructor is unwinding.
if (cudaPeekAtLastError() != cudaSuccess) {
return;
}
reset();
}

View File

@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
// Whether command encoding should use CUDA graph capture.
// Controlled by the MLX_USE_CUDA_GRAPHS environment variable
// (default: true); read once and cached for the process lifetime.
//
// NOTE: the original span contained two conflicting initializations of
// `use_graphs` (a leftover lambda form alongside the direct form), which
// is a duplicate declaration and cannot compile; only the direct
// initialization is kept.
bool use_cuda_graphs() {
  static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
  return use_graphs;
}
// Returns the file-name prefix for dumping captured CUDA graphs as dot
// files, taken from the MLX_SAVE_CUDA_GRAPHS_DOT_FILE environment
// variable. Returns nullptr when the variable is unset or empty (an
// empty value is treated the same as unset). The lookup is performed
// once and cached.
const char* save_cuda_graphs_dot_file() {
  static const char* filename = []() -> const char* {
    const char* value = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
    // Empty string means "disabled", same as the variable being absent.
    return (value == nullptr || *value == '\0') ? nullptr : value;
  }();
  return filename;
}
} // namespace
Device::Device(int device) : device_(device) {
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Save cuda graph to dot file
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
static int count = 0;
auto path = fmt::format("{}_{}.dot", filename, ++count);
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
}
// Reset state
from_nodes_.clear();
to_nodes_.clear();

View File

@@ -305,6 +305,7 @@ void Event::wait() {
} else {
event->atomic->wait(value());
}
CHECK_CUDA_ERROR(cudaPeekAtLastError());
}
void Event::wait(Stream s) {

View File

@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
return x;
}
// Allocates the data for `o` so that it has the same memory layout
// (axis ordering of strides) as `q`. For a row-contiguous `q` this is a
// plain contiguous allocation; otherwise `o` gets strides built in the
// same axis order as `q`'s, producing a "transposed contiguous" array.
void malloc_with_same_layout(
    cu::CommandEncoder& encoder,
    array& o,
    const array& q) {
  // Fast path: row-contiguous input needs no custom strides.
  if (q.flags().row_contiguous) {
    o.set_data(cu::malloc_async(o.nbytes(), encoder));
    return;
  }
  // order = argsort(q.strides()), ascending. stable_sort keeps ties in
  // their original axis order, making the result deterministic.
  Shape order(q.ndim());
  std::iota(order.begin(), order.end(), 0);
  auto effective_stride = [&q](int axis) {
    auto s = q.strides(axis);
    // Clamp non-positive strides to 1 so they sort as the innermost axes.
    return s > 0 ? s : 1;
  };
  std::stable_sort(order.begin(), order.end(), [&](int a, int b) {
    return effective_stride(a) < effective_stride(b);
  });
  // Fill o's strides innermost-first following that axis order.
  Strides o_strides(q.ndim());
  int64_t accumulated = 1;
  for (int axis : order) {
    o_strides[axis] = accumulated;
    accumulated *= o.shape(axis);
  }
  // o is contiguous under this permuted layout, but not row-contiguous.
  o.set_data(
      cu::malloc_async(o.nbytes(), encoder),
      o.size(),
      o_strides,
      {true, false, false});
}
constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
@@ -338,9 +370,7 @@ void sdpa_cudnn(
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder));
malloc_with_same_layout(encoder, o, q);
encoder.set_input_array(q);
encoder.set_input_array(k);
@@ -392,10 +422,9 @@ void sdpa_backward_cudnn(
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
malloc_with_same_layout(encoder, d_q, q);
malloc_with_same_layout(encoder, d_k, k);
malloc_with_same_layout(encoder, d_v, v);
encoder.set_input_array(q);
encoder.set_input_array(k);

View File

@@ -135,7 +135,11 @@ class Scheduler {
~Scheduler() {
for (auto s : streams_) {
synchronize(s);
try {
synchronize(s);
} catch (const std::runtime_error&) {
// ignore errors if synch fails
}
}
for (auto t : threads_) {
if (t != nullptr) {