// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include "cuda_jit_sources.h"

#include <cuda.h>
#include <fmt/format.h>
#include <nvrtc.h>
#include <nvtx3/nvtx3.hpp>

#include <cassert>
#include <numeric>

namespace mlx::core {

namespace {

constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};

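// Flatten the index arrays (inputs[1..nidx]) into kernel arguments: a vector
// of device pointers, followed by the shapes and strides of every index
// array, each packed as nidx * idx_ndim contiguous entries.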
void append_indices_arg(
    cu::KernelArgs& args,
    const std::vector<array>& inputs,
    int nidx,
    int idx_ndim) {
  std::vector<const void*> indices(nidx);
  for (int i = 0; i < nidx; ++i) {
    indices[i] = inputs[i + 1].data<void>();
  }
  args.append(std::move(indices));
  std::vector<int32_t> indices_shape(nidx * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy_n(
        inputs[i + 1].shape().begin(),
        idx_ndim,
        indices_shape.data() + i * idx_ndim);
  }
  args.append(std::move(indices_shape));
  std::vector<int64_t> indices_strides(nidx * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy_n(
        inputs[i + 1].strides().begin(),
        idx_ndim,
        indices_strides.data() + i * idx_ndim);
  }
  args.append(std::move(indices_strides));
}

} // namespace

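// Gather copies slices of size `slice_sizes_` from `src` at the positions
// given by the index arrays (inputs[1..nidx]) into `out`. Each op JIT-compiles
// one module per (output dtype, index dtype, nidx) combination and then picks
// the kernel instantiation for the index ndim and index width at launch time.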
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Gather::eval_gpu");
  assert(inputs.size() > 0);
  const auto& src = inputs[0];

  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  int nidx = inputs.size() - 1;
  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;

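  // Fall back to 64-bit indexing only when some participating array may hold
  // more than INT32_MAX elements; 32-bit index arithmetic is cheaper.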
  bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
      (src.size() > INT32_MAX) || (out.size() > INT32_MAX);

  uint32_t slice_size = std::accumulate(
      slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());

  std::string module_name = fmt::format(
      "gather_{}_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx_dtype),
      nidx);

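  // Register every (ndim, index width) instantiation up front so that any
  // later shape resolves to a kernel name already present in the module. For
  // example (assuming dtype_to_cuda_type maps float32 to float), a float32
  // gather with one int32 index array of ndim 2 resolves to
  // "mlx::core::cu::gather<float, int32_t, 1, 2, int32_t>".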
  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int large = 0; large <= 1; ++large) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
            dtype_to_cuda_type(out.dtype()),
            dtype_to_cuda_type(idx_dtype),
            nidx,
            ndim,
            large ? "int64_t" : "int32_t"));
      }
    }
    return std::make_pair(jit_source_gather, std::move(kernel_names));
  });

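  // The argument order below must mirror the JIT'd kernel's parameter list:
  // src and out pointers, output size, source geometry, slice geometry, the
  // gather axes, and finally the flattened index arrays.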
  cu::KernelArgs args;
  args.append(src);
  args.append(out);
  if (large) {
    args.append<int64_t>(out.size());
  } else {
    args.append<int32_t>(out.size());
  }
  args.append_ndim(src.shape());
  args.append_ndim(src.strides());
  args.append<int32_t>(src.ndim());
  args.append_ndim(slice_sizes_);
  args.append(slice_size);
  args.append(axes_);
  append_indices_arg(args, inputs, nidx, idx_ndim);

  std::string kernel_name = fmt::format(
      "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx_dtype),
      nidx,
      idx_ndim,
      large ? "int64_t" : "int32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);

  auto kernel = mod.get_kernel(kernel_name);
  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}

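// Scatter first copies the source into `out` and then applies `upd` in place,
// so elements not referenced by any index keep their original values. The
// accumulation at each destination (Max/Min/Sum/Prod/Assign) comes from
// g_scatter_ops, indexed by reduce_type_, and the launch grid is sized by the
// update array rather than the output.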
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Scatter::eval_gpu");
  assert(inputs.size() > 1);
  auto& upd = inputs.back();

  // Copy src into out.
  CopyType copy_type;
  if (inputs[0].data_size() == 1) {
    copy_type = CopyType::Scalar;
  } else if (inputs[0].flags().row_contiguous) {
    copy_type = CopyType::Vector;
  } else {
    copy_type = CopyType::General;
  }
  copy_gpu(inputs[0], out, copy_type);

  // Empty update.
  if (upd.size() == 0) {
    return;
  }

  int nidx = axes_.size();
  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;

  bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
      (upd.size() > INT32_MAX) || (out.size() > INT32_MAX);

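  // Number of update elements written per index entry: the product of the
  // update dimensions that come after the index dimensions.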
  int32_t upd_post_idx_size = std::accumulate(
      upd.shape().begin() + idx_ndim,
      upd.shape().end(),
      1,
      std::multiplies<int32_t>());

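  // The reduce type picks the kernel's accumulation functor by name via
  // g_scatter_ops; e.g. a sum-scatter instantiates mlx::core::cu::ScatterSum.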
  const char* op = g_scatter_ops[reduce_type_];
  std::string module_name = fmt::format(
      "scatter_{}_{}_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx_dtype),
      op,
      nidx);

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int large = 0; large <= 1; ++large) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
            dtype_to_cuda_type(out.dtype()),
            dtype_to_cuda_type(idx_dtype),
            op,
            nidx,
            ndim,
            large ? "int64_t" : "int32_t"));
      }
    }
    return std::make_pair(jit_source_scatter, std::move(kernel_names));
  });

  cu::KernelArgs args;
  args.append(upd);
  args.append(out);
  if (large) {
    args.append<int64_t>(upd.size());
  } else {
    args.append<int32_t>(upd.size());
  }
  args.append_ndim(upd.shape());
  args.append_ndim(upd.strides());
  args.append<int32_t>(upd.ndim());
  if (large) {
    args.append<int64_t>(upd_post_idx_size);
  } else {
    args.append<int32_t>(upd_post_idx_size);
  }
  args.append_ndim(out.shape());
  args.append_ndim(out.strides());
  args.append<int32_t>(out.ndim());
  args.append(axes_);
  append_indices_arg(args, inputs, nidx, idx_ndim);

  std::string kernel_name = fmt::format(
      "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx_dtype),
      op,
      nidx,
      idx_ndim,
      large ? "int64_t" : "int32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  auto kernel = mod.get_kernel(kernel_name);
  auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}

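// GatherAxis implements take_along_axis: it reads src along `axis_` at the
// positions given by `idx`. Kernels are additionally specialized on whether
// the source and index arrays are row-contiguous (two bits of `contiguous`).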
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("GatherAxis::eval_gpu");
  assert(inputs.size() > 1);
  const auto& src = inputs[0];
  const auto& idx = inputs[1];

  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;

  std::string module_name = fmt::format(
      "gather_axis_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx.dtype()));

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int contiguous = 0; contiguous < 4; ++contiguous) {
        for (int large = 0; large <= 1; ++large) {
          kernel_names.push_back(fmt::format(
              "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
              dtype_to_cuda_type(out.dtype()),
              dtype_to_cuda_type(idx.dtype()),
              ndim,
              contiguous & 1 ? true : false,
              contiguous & 2 ? true : false,
              large ? "int64_t" : "int32_t"));
        }
      }
    }
    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
  });

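  // Decompose the index array's size into the product of the dimensions
  // before, at, and after `axis_`, so the kernel can map a flat thread id to
  // an (outer, axis, inner) coordinate.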
  size_t idx_size_pre = 1;
  size_t idx_size_post = 1;
  for (int i = 0; i < axis_; ++i) {
    idx_size_pre *= idx.shape(i);
  }
  for (int i = axis_ + 1; i < idx.ndim(); ++i) {
    idx_size_post *= idx.shape(i);
  }
  size_t idx_size_axis = idx.shape(axis_);

  cu::KernelArgs args;
  args.append(src);
  args.append(idx);
  args.append(out);
  if (large) {
    args.append<int64_t>(idx_size_pre);
    args.append<int64_t>(idx_size_axis);
    args.append<int64_t>(idx_size_post);
  } else {
    args.append<int32_t>(idx_size_pre);
    args.append<int32_t>(idx_size_axis);
    args.append<int32_t>(idx_size_post);
  }
  args.append(remove_index(idx.shape(), axis_));
  args.append(remove_index(src.strides(), axis_));
  args.append(remove_index(idx.strides(), axis_));
  args.append<int32_t>(axis_);
  args.append(src.shape(axis_));
  args.append(src.strides(axis_));
  args.append(idx.strides(axis_));

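  // The kernel is specialized on src.ndim() - 1 dimensions: the gather axis
  // is peeled off and passed through the explicit shape/stride arguments
  // appended above.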
  std::string kernel_name = fmt::format(
      "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx.dtype()),
      src.ndim() - 1,
      src.flags().row_contiguous,
      idx.flags().row_contiguous,
      large ? "int64_t" : "int32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  auto kernel = mod.get_kernel(kernel_name);
  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}

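// ScatterAxis is the in-place counterpart of GatherAxis (put_along_axis):
// updates are written into `out` along `axis_`, either overwriting (Assign)
// or accumulating (Sum).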
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ScatterAxis::eval_gpu");
  assert(inputs.size() > 2);
  const auto& src = inputs[0];
  const auto& idx = inputs[1];
  const auto& upd = inputs[2];

  // Copy src into out.
  CopyType copy_type;
  if (src.data_size() == 1) {
    copy_type = CopyType::Scalar;
  } else if (src.flags().row_contiguous) {
    copy_type = CopyType::Vector;
  } else {
    copy_type = CopyType::General;
  }
  copy_gpu(src, out, copy_type);

  // Empty update.
  if (upd.size() == 0) {
    return;
  }

  bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;

  const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
  std::string module_name = fmt::format(
      "scatter_axis_{}_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx.dtype()),
      op);

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int contiguous = 0; contiguous < 4; ++contiguous) {
        for (int large = 0; large <= 1; ++large) {
          kernel_names.push_back(fmt::format(
              "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
              dtype_to_cuda_type(out.dtype()),
              dtype_to_cuda_type(idx.dtype()),
              op,
              ndim,
              contiguous & 1 ? true : false,
              contiguous & 2 ? true : false,
              large ? "int64_t" : "int32_t"));
        }
      }
    }
    return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
  });

  size_t idx_size_pre = 1;
  size_t idx_size_post = 1;
  for (int i = 0; i < axis_; ++i) {
    idx_size_pre *= idx.shape(i);
  }
  for (int i = axis_ + 1; i < idx.ndim(); ++i) {
    idx_size_post *= idx.shape(i);
  }
  size_t idx_size_axis = idx.shape(axis_);

  cu::KernelArgs args;
  args.append(upd);
  args.append(idx);
  args.append(out);
  if (large) {
    args.append<int64_t>(idx_size_pre);
    args.append<int64_t>(idx_size_axis);
    args.append<int64_t>(idx_size_post);
  } else {
    args.append<int32_t>(idx_size_pre);
    args.append<int32_t>(idx_size_axis);
    args.append<int32_t>(idx_size_post);
  }
  args.append(remove_index(idx.shape(), axis_));
  args.append(remove_index(upd.strides(), axis_));
  args.append(remove_index(idx.strides(), axis_));
  args.append<int32_t>(axis_);
  args.append(out.shape(axis_));
  args.append(upd.strides(axis_));
  args.append(idx.strides(axis_));

  std::string kernel_name = fmt::format(
      "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx.dtype()),
      op,
      idx.ndim() - 1,
      upd.flags().row_contiguous,
      idx.flags().row_contiguous,
      large ? "int64_t" : "int32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  auto kernel = mod.get_kernel(kernel_name);
  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}

} // namespace mlx::core