mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype fix signal bug + start to add dependencies capture more capture more ops remaining ops fix reduce and rope deps add concurrent context try update, but not working consistent topology order use node api use node api directly to reduce overhead fix bug use kernels in unary cache graph format fix synchronization format * comment
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -178,6 +179,7 @@ void Compiled::eval_gpu(
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
cu::KernelArgs args;
|
||||
// Put inputs.
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
@@ -185,26 +187,26 @@ void Compiled::eval_gpu(
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
mod.append_arg(x);
|
||||
args.append(x);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
mod.append_arg(strides_vec[strides_index++]);
|
||||
args.append_ptr(strides_vec[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Put outputs.
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
for (auto& x : outputs) {
|
||||
mod.append_arg(x);
|
||||
args.append(x);
|
||||
}
|
||||
|
||||
// Put shape and size.
|
||||
if (!contiguous) {
|
||||
mod.append_arg(shape);
|
||||
args.append_ptr(shape.data());
|
||||
}
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(outputs[0].data_size());
|
||||
args.append<int64_t>(outputs[0].data_size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(outputs[0].data_size());
|
||||
args.append<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
@@ -222,9 +224,10 @@ void Compiled::eval_gpu(
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, outputs[0], large);
|
||||
});
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user