[CUDA] Switch to CUDA graphs (#2317)

* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

cosistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
This commit is contained in:
Awni Hannun
2025-07-02 15:59:13 -07:00
committed by GitHub
parent e76e9b87f0
commit ec0d5db67b
36 changed files with 1461 additions and 1212 deletions

View File

@@ -3,13 +3,16 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "cuda_jit_sources.h"
#include <cuda.h>
#include <fmt/format.h>
#include <nvrtc.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
@@ -22,7 +25,7 @@ namespace {
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
void append_indices_arg(
cu::JitModule& mod,
cu::KernelArgs& args,
const std::vector<array>& inputs,
int nidx,
int idx_ndim) {
@@ -30,7 +33,7 @@ void append_indices_arg(
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
}
mod.append_arg(std::move(indices));
args.append(std::move(indices));
std::vector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
@@ -38,7 +41,7 @@ void append_indices_arg(
idx_ndim,
indices_shape.data() + i * idx_ndim);
}
mod.append_arg(std::move(indices_shape));
args.append(std::move(indices_shape));
std::vector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
@@ -46,7 +49,7 @@ void append_indices_arg(
idx_ndim,
indices_strides.data() + i * idx_ndim);
}
mod.append_arg(std::move(indices_strides));
args.append(std::move(indices_strides));
}
} // namespace
@@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
return std::make_pair(jit_source_gather, std::move(kernel_names));
});
mod.append_arg(src);
mod.append_arg(out);
cu::KernelArgs args;
args.append(src);
args.append(out);
if (large) {
mod.append_arg<int64_t>(out.size());
args.append<int64_t>(out.size());
} else {
mod.append_arg<int32_t>(out.size());
args.append<int32_t>(out.size());
}
mod.append_ndim_arg(src.shape());
mod.append_ndim_arg(src.strides());
mod.append_arg<int32_t>(src.ndim());
mod.append_ndim_arg(slice_sizes_);
mod.append_arg(slice_size);
mod.append_arg(axes_);
append_indices_arg(mod, inputs, nidx, idx_ndim);
args.append_ndim(src.shape());
args.append_ndim(src.strides());
args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_);
args.append(slice_size);
args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
@@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, out, large);
});
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
return std::make_pair(jit_source_scatter, std::move(kernel_names));
});
mod.append_arg(upd);
mod.append_arg(out);
cu::KernelArgs args;
args.append(upd);
args.append(out);
if (large) {
mod.append_arg<int64_t>(upd.size());
args.append<int64_t>(upd.size());
} else {
mod.append_arg<int32_t>(upd.size());
args.append<int32_t>(upd.size());
}
mod.append_ndim_arg(upd.shape());
mod.append_ndim_arg(upd.strides());
mod.append_arg<int32_t>(upd.ndim());
args.append_ndim(upd.shape());
args.append_ndim(upd.strides());
args.append<int32_t>(upd.ndim());
if (large) {
mod.append_arg<int64_t>(upd_post_idx_size);
args.append<int64_t>(upd_post_idx_size);
} else {
mod.append_arg<int32_t>(upd_post_idx_size);
args.append<int32_t>(upd_post_idx_size);
}
mod.append_ndim_arg(out.shape());
mod.append_ndim_arg(out.strides());
mod.append_arg<int32_t>(out.ndim());
mod.append_arg(axes_);
append_indices_arg(mod, inputs, nidx, idx_ndim);
args.append_ndim(out.shape());
args.append_ndim(out.strides());
args.append<int32_t>(out.ndim());
args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
@@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, upd, large);
});
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
size_t idx_size_axis = idx.shape(axis_);
mod.append_arg(src);
mod.append_arg(idx);
mod.append_arg(out);
cu::KernelArgs args;
args.append(src);
args.append(idx);
args.append(out);
if (large) {
mod.append_arg<int64_t>(idx_size_pre);
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
args.append<int64_t>(idx_size_pre);
args.append<int64_t>(idx_size_axis);
args.append<int64_t>(idx_size_post);
} else {
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
args.append<int32_t>(idx_size_pre);
args.append<int32_t>(idx_size_axis);
args.append<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(src.strides(), axis_));
mod.append_arg(remove_index(idx.strides(), axis_));
mod.append_arg<int32_t>(axis_);
mod.append_arg(src.shape(axis_));
mod.append_arg(src.strides(axis_));
mod.append_arg(idx.strides(axis_));
args.append(remove_index(idx.shape(), axis_));
args.append(remove_index(src.strides(), axis_));
args.append(remove_index(idx.strides(), axis_));
args.append<int32_t>(axis_);
args.append(src.shape(axis_));
args.append(src.strides(axis_));
args.append(idx.strides(axis_));
std::string kernel_name = fmt::format(
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
@@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, idx, large);
});
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
size_t idx_size_axis = idx.shape(axis_);
mod.append_arg(upd);
mod.append_arg(idx);
mod.append_arg(out);
cu::KernelArgs args;
args.append(upd);
args.append(idx);
args.append(out);
if (large) {
mod.append_arg<int64_t>(idx_size_pre);
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
args.append<int64_t>(idx_size_pre);
args.append<int64_t>(idx_size_axis);
args.append<int64_t>(idx_size_post);
} else {
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
args.append<int32_t>(idx_size_pre);
args.append<int32_t>(idx_size_axis);
args.append<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(upd.strides(), axis_));
mod.append_arg(remove_index(idx.strides(), axis_));
mod.append_arg<int32_t>(axis_);
mod.append_arg(out.shape(axis_));
mod.append_arg(upd.strides(axis_));
mod.append_arg(idx.strides(axis_));
args.append(remove_index(idx.shape(), axis_));
args.append(remove_index(upd.strides(), axis_));
args.append(remove_index(idx.strides(), axis_));
args.append<int32_t>(axis_);
args.append(out.shape(axis_));
args.append(upd.strides(axis_));
args.append(idx.strides(axis_));
std::string kernel_name = fmt::format(
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
@@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, idx, large);
});
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
} // namespace mlx::core