[CUDA] Switch to CUDA graphs (#2317)

* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

consistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
Awni Hannun 2025-07-02 15:59:13 -07:00 committed by GitHub
parent e76e9b87f0
commit ec0d5db67b
36 changed files with 1461 additions and 1212 deletions
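
Before the diff: a minimal, standalone sketch (not MLX code) of the CUDA graph pattern this change adopts. Instead of launching kernels eagerly on a stream, kernel nodes are added to a `cudaGraph_t` with explicit dependencies, the graph is instantiated once, and the executable graph is launched (and can later be updated or relaunched cheaply). The kernel and buffer names below are illustrative only.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(float* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

int main() {
  const int n = 1 << 20;
  float* x;
  cudaMalloc(&x, n * sizeof(float));
  float s = 2.0f;

  // Build a graph with two dependent kernel nodes instead of launching
  // each kernel eagerly on a stream.
  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);
  void* params[] = {&x, &s, (void*)&n};
  cudaKernelNodeParams kp = {};
  kp.func = (void*)scale;
  kp.gridDim = dim3((n + 255) / 256);
  kp.blockDim = dim3(256);
  kp.kernelParams = params;
  cudaGraphNode_t a, b;
  cudaGraphAddKernelNode(&a, graph, nullptr, 0, &kp);
  cudaGraphAddKernelNode(&b, graph, &a, 1, &kp); // b runs after a

  // Instantiate once, then launch the whole graph; the executable graph
  // can be relaunched, or patched in place with cudaGraphExecUpdate.
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
  cudaGraphLaunch(exec, 0);
  cudaStreamSynchronize(0);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(x);
  printf("done\n");
  return 0;
}
```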


@@ -12,16 +12,11 @@ namespace mlx::core {
 inline std::tuple<Shape, Strides, Strides> collapse_batches(
     const array& a,
     const array& b) {
-  // Get and check the shape for the batched dims
-  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
-  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
-  if (A_bshape != B_bshape) {
-    std::ostringstream msg;
-    msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
-        << a.shape() << ", B " << b.shape() << ".";
-    throw std::runtime_error(msg.str());
+  if (a.ndim() == 2) {
+    return {{1}, {0}, {0}};
   }
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
   Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
   Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
@@ -42,17 +37,11 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
 inline std::tuple<Shape, Strides, Strides, Strides>
 collapse_batches(const array& a, const array& b, const array& c) {
-  // Get and check the shape for the batched dims
-  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
-  Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
-  Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
-  if (A_bshape != B_bshape || A_bshape != C_bshape) {
-    std::ostringstream msg;
-    msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
-        << a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
-    throw std::runtime_error(msg.str());
+  if (a.ndim() == 2) {
+    return {{1}, {0}, {0}, {0}};
   }
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
   Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
   Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
   Strides C_bstride{c.strides().begin(), c.strides().end() - 2};


@@ -151,30 +151,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
-      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-      constexpr uint32_t N_READS = 4;
-      dispatch_block_dim(
-          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-            auto kernel =
-                cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
-            if (reduce_type_ == ArgReduce::ArgMin) {
-              kernel = cu::
-                  arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
-            }
-            kernel<<<num_blocks, block_dim(), 0, stream>>>(
-                in.data<T>(),
-                out.data<uint32_t>(),
-                out.size(),
-                const_param(shape),
-                const_param(in_strides),
-                const_param(out_strides),
-                ndim,
-                axis_stride,
-                axis_size);
-          });
-    });
+  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr uint32_t N_READS = 4;
+    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+      dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
+      auto kernel =
+          cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
+      if (reduce_type_ == ArgReduce::ArgMin) {
+        kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
+      }
+      encoder.add_kernel_node(
+          kernel,
+          num_blocks,
+          block_dim(),
+          in.data<T>(),
+          out.data<uint32_t>(),
+          out.size(),
+          const_param(shape),
+          const_param(in_strides),
+          const_param(out_strides),
+          ndim,
+          axis_stride,
+          axis_size);
+    });
   });
 }


@@ -139,90 +139,92 @@ void binary_op_gpu_inplace(
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
-        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
-        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
-          using InType = cuda_type_t<CTYPE_IN>;
-          using OutType = cuda_type_t<CTYPE_OUT>;
-          auto bopt = get_binary_op_type(a, b);
-          if (bopt == BinaryOpType::General) {
-            dispatch_bool(
-                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
-                    out.data_size() > INT32_MAX,
-                [&](auto large) {
-                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-                  Shape shape;
-                  std::vector<Strides> strides;
-                  std::tie(shape, strides) =
-                      collapse_contiguous_dims(a, b, out);
-                  auto& a_strides = strides[0];
-                  auto& b_strides = strides[1];
-                  int ndim = shape.size();
-                  if (ndim <= 3) {
-                    dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                      auto kernel = cu::binary_g_nd<
-                          Op,
-                          InType,
-                          OutType,
-                          IdxT,
-                          dims_constant()>;
-                      auto [num_blocks, block_dims] =
-                          get_launch_args(kernel, out, large());
-                      kernel<<<num_blocks, block_dims, 0, stream>>>(
-                          a.data<InType>(),
-                          b.data<InType>(),
-                          out.data<OutType>(),
-                          out.size(),
-                          const_param<dims_constant()>(shape),
-                          const_param<dims_constant()>(a_strides),
-                          const_param<dims_constant()>(b_strides));
-                    });
-                  } else {
-                    auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
-                    auto [num_blocks, block_dims] =
-                        get_launch_args(kernel, out, large());
-                    kernel<<<num_blocks, block_dims, 0, stream>>>(
-                        a.data<InType>(),
-                        b.data<InType>(),
-                        out.data<OutType>(),
-                        out.size(),
-                        const_param(shape),
-                        const_param(a_strides),
-                        const_param(b_strides),
-                        ndim);
-                  }
-                });
-          } else {
-            dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-              using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
-              if (bopt == BinaryOpType::ScalarVector) {
-                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorScalar) {
-                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorVector) {
-                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
-              }
-              auto [num_blocks, block_dims] = get_launch_args(
-                  kernel, out.data_size(), out.shape(), out.strides(), large());
-              kernel<<<num_blocks, block_dims, 0, stream>>>(
-                  a.data<InType>(),
-                  b.data<InType>(),
-                  out.data<OutType>(),
-                  out.data_size());
-            });
-          }
-        } else {
-          throw std::runtime_error(fmt::format(
-              "Can not do binary op {} on inputs of {} with result of {}.",
-              op,
-              dtype_to_string(a.dtype()),
-              dtype_to_string(out.dtype())));
-        }
-      });
+  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
+      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+        using InType = cuda_type_t<CTYPE_IN>;
+        using OutType = cuda_type_t<CTYPE_OUT>;
+        auto bopt = get_binary_op_type(a, b);
+        if (bopt == BinaryOpType::General) {
+          dispatch_bool(
+              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                  out.data_size() > INT32_MAX,
+              [&](auto large) {
+                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                Shape shape;
+                std::vector<Strides> strides;
+                std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
+                auto& a_strides = strides[0];
+                auto& b_strides = strides[1];
+                int ndim = shape.size();
+                if (ndim <= 3) {
+                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                    auto kernel = cu::
+                        binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
+                    auto [num_blocks, block_dims] =
+                        get_launch_args(kernel, out, large());
+                    encoder.add_kernel_node(
+                        kernel,
+                        num_blocks,
+                        block_dims,
+                        a.data<InType>(),
+                        b.data<InType>(),
+                        out.data<OutType>(),
+                        out.size(),
+                        const_param<dims_constant()>(shape),
+                        const_param<dims_constant()>(a_strides),
+                        const_param<dims_constant()>(b_strides));
+                  });
+                } else {
+                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out, large());
+                  encoder.add_kernel_node(
+                      kernel,
+                      num_blocks,
+                      block_dims,
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out.data<OutType>(),
+                      out.size(),
+                      const_param(shape),
+                      const_param(a_strides),
+                      const_param(b_strides),
+                      ndim);
+                }
+              });
+        } else {
+          dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
+            if (bopt == BinaryOpType::ScalarVector) {
+              kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
+            } else if (bopt == BinaryOpType::VectorScalar) {
+              kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
+            } else if (bopt == BinaryOpType::VectorVector) {
+              kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
+            }
+            auto [num_blocks, block_dims] = get_launch_args(
+                kernel, out.data_size(), out.shape(), out.strides(), large());
+            encoder.add_kernel_node(
+                kernel,
+                num_blocks,
+                block_dims,
+                a.data<InType>(),
+                b.data<InType>(),
+                out.data<OutType>(),
+                out.data_size());
+          });
+        }
+      } else {
+        throw std::runtime_error(fmt::format(
+            "Can not do binary op {} on inputs of {} with result of {}.",
+            op,
+            dtype_to_string(a.dtype()),
+            dtype_to_string(out.dtype())));
+      }
     });
   });
 }


@@ -137,98 +137,101 @@ void binary_op_gpu_inplace(
   encoder.set_input_array(b);
   encoder.set_output_array(out_a);
   encoder.set_output_array(out_b);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
-        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
-        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
-        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
-          using InType = cuda_type_t<CTYPE_IN>;
-          using OutType = cuda_type_t<CTYPE_OUT>;
-          auto bopt = get_binary_op_type(a, b);
-          if (bopt == BinaryOpType::General) {
-            dispatch_bool(
-                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
-                    out_a.data_size() > INT32_MAX,
-                [&](auto large) {
-                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-                  Shape shape;
-                  std::vector<Strides> strides;
-                  std::tie(shape, strides) =
-                      collapse_contiguous_dims(a, b, out_a);
-                  auto& a_strides = strides[0];
-                  auto& b_strides = strides[1];
-                  int ndim = shape.size();
-                  if (ndim <= 3) {
-                    dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                      auto kernel = cu::binary_g_nd<
-                          Op,
-                          InType,
-                          OutType,
-                          IdxT,
-                          dims_constant()>;
-                      auto [num_blocks, block_dims] =
-                          get_launch_args(kernel, out_a, large());
-                      kernel<<<num_blocks, block_dims, 0, stream>>>(
-                          a.data<InType>(),
-                          b.data<InType>(),
-                          out_a.data<OutType>(),
-                          out_b.data<OutType>(),
-                          out_a.size(),
-                          const_param<dims_constant()>(shape),
-                          const_param<dims_constant()>(a_strides),
-                          const_param<dims_constant()>(b_strides));
-                    });
-                  } else {
-                    auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
-                    auto [num_blocks, block_dims] =
-                        get_launch_args(kernel, out_a, large());
-                    kernel<<<num_blocks, block_dims, 0, stream>>>(
-                        a.data<InType>(),
-                        b.data<InType>(),
-                        out_a.data<OutType>(),
-                        out_b.data<OutType>(),
-                        out_a.size(),
-                        const_param(shape),
-                        const_param(a_strides),
-                        const_param(b_strides),
-                        ndim);
-                  }
-                });
-          } else {
-            dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
-              using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
-              if (bopt == BinaryOpType::ScalarVector) {
-                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorScalar) {
-                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
-              } else if (bopt == BinaryOpType::VectorVector) {
-                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
-              }
-              auto [num_blocks, block_dims] = get_launch_args(
-                  kernel,
-                  out_a.data_size(),
-                  out_a.shape(),
-                  out_a.strides(),
-                  large());
-              kernel<<<num_blocks, block_dims, 0, stream>>>(
-                  a.data<InType>(),
-                  b.data<InType>(),
-                  out_a.data<OutType>(),
-                  out_b.data<OutType>(),
-                  out_a.data_size());
-            });
-          }
-        } else {
-          throw std::runtime_error(fmt::format(
-              "Can not do binary op {} on inputs of {} with result of {}.",
-              op,
-              dtype_to_string(a.dtype()),
-              dtype_to_string(out_a.dtype())));
-        }
-      });
+  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
+      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
+      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+        using InType = cuda_type_t<CTYPE_IN>;
+        using OutType = cuda_type_t<CTYPE_OUT>;
+
+        auto bopt = get_binary_op_type(a, b);
+        if (bopt == BinaryOpType::General) {
+          dispatch_bool(
+              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                  out_a.data_size() > INT32_MAX,
+              [&](auto large) {
+                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                Shape shape;
+                std::vector<Strides> strides;
+                std::tie(shape, strides) =
+                    collapse_contiguous_dims(a, b, out_a);
+                auto& a_strides = strides[0];
+                auto& b_strides = strides[1];
+                int ndim = shape.size();
+                if (ndim <= 3) {
+                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                    auto kernel = cu::
+                        binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
+                    auto [num_blocks, block_dims] =
+                        get_launch_args(kernel, out_a, large());
+                    encoder.add_kernel_node(
+                        kernel,
+                        num_blocks,
+                        block_dims,
+                        a.data<InType>(),
+                        b.data<InType>(),
+                        out_a.data<OutType>(),
+                        out_b.data<OutType>(),
+                        out_a.size(),
+                        const_param<dims_constant()>(shape),
+                        const_param<dims_constant()>(a_strides),
+                        const_param<dims_constant()>(b_strides));
+                  });
+                } else {
+                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out_a, large());
+                  encoder.add_kernel_node(
+                      kernel,
+                      num_blocks,
+                      block_dims,
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out_a.data<OutType>(),
+                      out_b.data<OutType>(),
+                      out_a.size(),
+                      const_param(shape),
+                      const_param(a_strides),
+                      const_param(b_strides),
+                      ndim);
+                }
+              });
+        } else {
+          dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
+            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
+            if (bopt == BinaryOpType::ScalarVector) {
+              kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
+            } else if (bopt == BinaryOpType::VectorScalar) {
+              kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
+            } else if (bopt == BinaryOpType::VectorVector) {
+              kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
+            }
+            auto [num_blocks, block_dims] = get_launch_args(
+                kernel,
+                out_a.data_size(),
+                out_a.shape(),
+                out_a.strides(),
+                large());
+            encoder.add_kernel_node(
+                kernel,
+                num_blocks,
+                block_dims,
+                a.data<InType>(),
+                b.data<InType>(),
+                out_a.data<OutType>(),
+                out_b.data<OutType>(),
+                out_a.data_size());
+          });
+        }
+      } else {
+        throw std::runtime_error(fmt::format(
+            "Can not do binary op {} on inputs of {} with result of {}.",
+            op,
+            dtype_to_string(a.dtype()),
+            dtype_to_string(out_a.dtype())));
+      }
     });
   });
 }


@@ -3,6 +3,7 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/graph_utils.h"
 #include "mlx/primitives.h"
@@ -178,6 +179,7 @@ void Compiled::eval_gpu(
   // Whether to use large index.
   bool large = compiled_use_large_index(inputs, outputs, contiguous);
 
+  cu::KernelArgs args;
   // Put inputs.
   int strides_index = 1;
   for (size_t i = 0; i < inputs.size(); ++i) {
@@ -185,26 +187,26 @@ void Compiled::eval_gpu(
       continue;
     }
     const auto& x = inputs[i];
-    mod.append_arg(x);
+    args.append(x);
     if (!contiguous && !is_scalar(x)) {
-      mod.append_arg(strides_vec[strides_index++]);
+      args.append_ptr(strides_vec[strides_index++].data());
     }
   }
 
   // Put outputs.
   compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
   for (auto& x : outputs) {
-    mod.append_arg(x);
+    args.append(x);
   }
 
   // Put shape and size.
   if (!contiguous) {
-    mod.append_arg(shape);
+    args.append_ptr(shape.data());
   }
   if (large) {
-    mod.append_arg<int64_t>(outputs[0].data_size());
+    args.append<int64_t>(outputs[0].data_size());
   } else {
-    mod.append_arg<uint32_t>(outputs[0].data_size());
+    args.append<uint32_t>(outputs[0].data_size());
   }
 
   // Launch kernel.
@@ -222,9 +224,10 @@ void Compiled::eval_gpu(
   for (const auto& out : outputs) {
     encoder.set_output_array(out);
   }
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    mod.launch_kernel(stream, kernel_name, outputs[0], large);
-  });
+
+  auto kernel = mod.get_kernel(kernel_name);
+  auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
 } // namespace mlx::core


@@ -35,24 +35,25 @@ void copy_contiguous(
     array& out,
     int64_t in_offset,
     int64_t out_offset) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-          using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-          using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-          auto kernel = cu::copy_s<InType, OutType, IdxT>;
-          if (ctype == CopyType::Vector) {
-            kernel = cu::copy_v<InType, OutType, IdxT>;
-          }
-          auto [num_blocks, block_dims] = get_launch_args(
-              kernel, out.data_size(), out.shape(), out.strides(), large());
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>() + in_offset,
-              out.data<OutType>() + out_offset,
-              out.data_size());
-        });
-      });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
+        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
+        auto kernel = cu::copy_s<InType, OutType, IdxT>;
+        if (ctype == CopyType::Vector) {
+          kernel = cu::copy_v<InType, OutType, IdxT>;
+        }
+        auto [num_blocks, block_dims] = get_launch_args(
+            kernel, out.data_size(), out.shape(), out.strides(), large());
+        encoder.add_kernel_node(
+            kernel,
+            num_blocks,
+            block_dims,
+            in.data<InType>() + in_offset,
+            out.data<OutType>() + out_offset,
+            out.data_size());
+      });
     });
   });


@@ -55,50 +55,54 @@ void copy_general(
     const Shape& shape,
     const Strides& strides_in,
     const Strides& strides_out) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(
-            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
-            [&](auto large) {
-              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-              const InType* in_ptr = in.data<InType>() + offset_in;
-              OutType* out_ptr = out.data<OutType>() + offset_out;
-              int ndim = shape.size();
-              size_t data_size = 1;
-              for (auto& s : shape)
-                data_size *= s;
-              if (ndim <= 3) {
-                dispatch_1_2_3(ndim, [&](auto ndim_constant) {
-                  auto kernel =
-                      cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
-                  auto [num_blocks, block_dims] = get_launch_args(
-                      kernel, data_size, shape, out.strides(), large());
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      in_ptr,
-                      out_ptr,
-                      data_size,
-                      const_param<ndim_constant()>(shape),
-                      const_param<ndim_constant()>(strides_in),
-                      const_param<ndim_constant()>(strides_out));
-                });
-              } else { // ndim >= 4
-                auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-                auto [num_blocks, block_dims] = get_launch_args(
-                    kernel, data_size, shape, out.strides(), large());
-                kernel<<<num_blocks, block_dims, 0, stream>>>(
-                    in_ptr,
-                    out_ptr,
-                    data_size,
-                    const_param(shape),
-                    const_param(strides_in),
-                    const_param(strides_out),
-                    ndim);
-              }
-            });
-      });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(
+          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+          [&](auto large) {
+            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+            const InType* in_ptr = in.data<InType>() + offset_in;
+            OutType* out_ptr = out.data<OutType>() + offset_out;
+            int ndim = shape.size();
+            size_t data_size = 1;
+            for (auto& s : shape)
+              data_size *= s;
+            if (ndim <= 3) {
+              dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+                auto kernel =
+                    cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
+                auto [num_blocks, block_dims] = get_launch_args(
+                    kernel, data_size, shape, out.strides(), large());
+                encoder.add_kernel_node(
+                    kernel,
+                    num_blocks,
+                    block_dims,
+                    in_ptr,
+                    out_ptr,
+                    data_size,
+                    const_param<ndim_constant()>(shape),
+                    const_param<ndim_constant()>(strides_in),
+                    const_param<ndim_constant()>(strides_out));
+              });
+            } else { // ndim >= 4
+              auto kernel = cu::copy_gg<InType, OutType, IdxT>;
+              auto [num_blocks, block_dims] = get_launch_args(
+                  kernel, data_size, shape, out.strides(), large());
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dims,
+                  in_ptr,
+                  out_ptr,
+                  data_size,
+                  const_param(shape),
+                  const_param(strides_in),
+                  const_param(strides_out),
+                  ndim);
+            }
+          });
     });
   });
 }


@@ -61,54 +61,55 @@ void copy_general_dynamic(
     const Strides& strides_out,
     const array& dynamic_offset_in,
     const array& dynamic_offset_out) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(
-            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
-            [&](auto large) {
-              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-              const InType* in_ptr = in.data<InType>() + offset_in;
-              OutType* out_ptr = out.data<OutType>() + offset_out;
-              int ndim = shape.size();
-              if (ndim <= 3) {
-                dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                  auto kernel = cu::copy_gg_dynamic_nd<
-                      InType,
-                      OutType,
-                      IdxT,
-                      dims_constant()>;
-                  auto [num_blocks, block_dims] =
-                      get_launch_args(kernel, out, large());
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      in_ptr,
-                      out_ptr,
-                      out.size(),
-                      const_param<dims_constant()>(shape),
-                      const_param<dims_constant()>(strides_in),
-                      const_param<dims_constant()>(strides_out),
-                      dynamic_offset_in.data<int64_t>(),
-                      dynamic_offset_out.data<int64_t>());
-                });
-              } else { // ndim >= 4
-                auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
-                auto [num_blocks, block_dims] =
-                    get_launch_args(kernel, out, large());
-                kernel<<<num_blocks, block_dims, 0, stream>>>(
-                    in_ptr,
-                    out_ptr,
-                    out.size(),
-                    const_param(shape),
-                    const_param(strides_in),
-                    const_param(strides_out),
-                    ndim,
-                    dynamic_offset_in.data<int64_t>(),
-                    dynamic_offset_out.data<int64_t>());
-              }
-            });
-      });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(
+          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+          [&](auto large) {
+            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+            const InType* in_ptr = in.data<InType>() + offset_in;
+            OutType* out_ptr = out.data<OutType>() + offset_out;
+            int ndim = shape.size();
+            if (ndim <= 3) {
+              dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                auto kernel = cu::
+                    copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large());
+                encoder.add_kernel_node(
+                    kernel,
+                    num_blocks,
+                    block_dims,
+                    in_ptr,
+                    out_ptr,
+                    out.size(),
+                    const_param<dims_constant()>(shape),
+                    const_param<dims_constant()>(strides_in),
+                    const_param<dims_constant()>(strides_out),
+                    dynamic_offset_in.data<int64_t>(),
+                    dynamic_offset_out.data<int64_t>());
+              });
+            } else { // ndim >= 4
+              auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
+              auto [num_blocks, block_dims] =
+                  get_launch_args(kernel, out, large());
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dims,
+                  in_ptr,
+                  out_ptr,
+                  out.size(),
+                  const_param(shape),
+                  const_param(strides_in),
+                  const_param(strides_out),
+                  ndim,
+                  dynamic_offset_in.data<int64_t>(),
+                  dynamic_offset_out.data<int64_t>());
+            }
+          });
     });
   });
 }


@@ -50,45 +50,49 @@ void copy_general_input(
     int64_t offset_out,
     const Shape& shape,
     const Strides& strides_in) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
-      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-        dispatch_bool(
-            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
-            [&](auto large) {
-              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
-              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-              const InType* in_ptr = in.data<InType>() + offset_in;
-              OutType* out_ptr = out.data<OutType>() + offset_out;
-              int ndim = shape.size();
-              if (ndim <= 3) {
-                dispatch_1_2_3(ndim, [&](auto dims_constant) {
-                  auto kernel =
-                      cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
-                  auto [num_blocks, block_dims] =
-                      get_launch_args(kernel, out, large());
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      in_ptr,
-                      out_ptr,
-                      out.size(),
-                      const_param<dims_constant()>(shape),
-                      const_param<dims_constant()>(strides_in));
-                });
-              } else { // ndim >= 4
-                auto kernel = cu::copy_g<InType, OutType, IdxT>;
-                auto [num_blocks, block_dims] =
-                    get_launch_args(kernel, out, large());
-                kernel<<<num_blocks, block_dims, 0, stream>>>(
-                    in_ptr,
-                    out_ptr,
-                    out.size(),
-                    const_param(shape),
-                    const_param(strides_in),
-                    ndim);
-              }
-            });
-      });
+  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+      dispatch_bool(
+          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+          [&](auto large) {
+            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+            const InType* in_ptr = in.data<InType>() + offset_in;
+            OutType* out_ptr = out.data<OutType>() + offset_out;
+            int ndim = shape.size();
+            if (ndim <= 3) {
+              dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                auto kernel =
+                    cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out, large());
+                encoder.add_kernel_node(
+                    kernel,
+                    num_blocks,
+                    block_dims,
+                    in_ptr,
+                    out_ptr,
+                    out.size(),
+                    const_param<dims_constant()>(shape),
+                    const_param<dims_constant()>(strides_in));
+              });
+            } else { // ndim >= 4
+              auto kernel = cu::copy_g<InType, OutType, IdxT>;
+              auto [num_blocks, block_dims] =
+                  get_launch_args(kernel, out, large());
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dims,
+                  in_ptr,
+                  out_ptr,
+                  out.size(),
+                  const_param(shape),
+                  const_param(strides_in),
+                  ndim);
+            }
+          });
     });
   });
 }


@@ -2,38 +2,23 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/metal/metal.h"
+#include "mlx/utils.h"
 
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 #include <future>
+#include <unordered_set>
 
 namespace mlx::core {
 
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+// This should be less than 255
+constexpr int default_max_nodes_per_graph = 20;
+
+constexpr int max_graph_cache_size = 100;
+
 namespace cu {
 
-DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
-
-void DeviceStream::synchronize() {
-  cudaStreamSynchronize(stream_);
-}
-
-cudaStream_t DeviceStream::schedule_cuda_stream() {
-  // TODO: Return a stream that maximizes parallelism.
-  return stream_;
-}
-
-cudaStream_t DeviceStream::last_cuda_stream() {
-  return stream_;
-}
-
-CommandEncoder& DeviceStream::get_encoder() {
-  if (!encoder_) {
-    encoder_ = std::make_unique<CommandEncoder>(*this);
-  }
-  return *encoder_;
-}
-
 Device::Device(int device) : device_(device) {
   CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
       &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
@@ -67,49 +52,253 @@ void Device::make_current() {
   }
 }
 
-DeviceStream& Device::get_stream(Stream s) {
-  auto it = streams_.find(s.index);
-  if (it == streams_.end()) {
-    it = streams_.try_emplace(s.index, *this).first;
+CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
+  CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
+  CHECK_CUDA_ERROR(cudaStreamBeginCaptureToGraph(
+      enc.stream(), graph, NULL, NULL, 0, cudaStreamCaptureModeGlobal));
+}
+
+CommandEncoder::CaptureContext::~CaptureContext() {
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  size_t num_nodes;
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
+  if (num_nodes == 1) {
+    cudaGraphNode_t captured_node;
+    CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
+    CUDA_KERNEL_NODE_PARAMS params;
+    CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
+    cudaGraphNode_t node;
+    CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
+    enc.insert_graph_dependencies(GraphNode{node, 'K'});
+  } else {
+    cudaGraphNode_t node;
+    CHECK_CUDA_ERROR(
+        cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
+    enc.insert_graph_dependencies(GraphNode{node, 'G'});
+  }
+  CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
+}
+
+CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
+    : enc(enc) {
+  enc.in_concurrent_ = true;
+}
+
+CommandEncoder::ConcurrentContext::~ConcurrentContext() {
+  enc.in_concurrent_ = false;
+
+  // Use an empty graph node for synchronization
+  CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
+  enc.empty_node_count_++;
+  CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
+
+  // Insert the concurrent -> empty node dependencies
+  for (auto& from : enc.concurrent_nodes_) {
+    enc.from_nodes_.push_back(from.node);
+    enc.to_nodes_.push_back(empty.node);
+    enc.graph_key_ += from.id;
+    enc.graph_key_ += from.node_type;
+    enc.graph_key_ += empty.id;
+    enc.graph_key_ += empty.node_type;
+  }
+
+  // Insert the input -> concurrent node dependencies without updating output
+  // nodes
+  auto outputs = std::move(enc.active_outputs_);
+  enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));
+
+  // Update output node to be the empty node
+  for (auto o : outputs) {
+    enc.node_map_.emplace(o, empty).first->second = empty;
+  }
+}
+
+void CommandEncoder::insert_graph_dependencies(GraphNode node) {
+  if (node.node_type == 'G') {
+    graph_node_count_++;
+  }
+  node.id = std::to_string(node_count_++);
+  if (in_concurrent_) {
+    concurrent_nodes_.push_back(std::move(node));
+  } else {
+    std::vector<GraphNode> nodes;
+    nodes.push_back(std::move(node));
+    insert_graph_dependencies(std::move(nodes));
+  }
+}
+
+void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
+  std::vector<GraphNode> deps;
+  {
+    // Dependencies must be added in the same order to produce a consistent
+    // topology
+    std::unordered_set<cudaGraphNode_t> set_deps;
+    for (auto d : active_deps_) {
+      if (auto it = node_map_.find(d); it != node_map_.end()) {
+        auto [_, inserted] = set_deps.insert(it->second.node);
+        if (inserted) {
+          deps.push_back(it->second);
+        }
+      }
+    }
+  }
+  active_deps_.clear();
+
+  for (auto o : active_outputs_) {
+    for (auto& node : nodes) {
+      node_map_.emplace(o, node).first->second = node;
+    }
+  }
+  active_outputs_.clear();
+
+  for (auto& from : deps) {
+    for (auto& to : nodes) {
+      from_nodes_.push_back(from.node);
+      to_nodes_.push_back(to.node);
+      graph_key_ += from.id;
+      graph_key_ += from.node_type;
+      graph_key_ += to.id;
+      graph_key_ += to.node_type;
+    }
+  }
+}
+
+CommandEncoder& Device::get_command_encoder(Stream s) {
+  auto it = encoders_.find(s.index);
+  if (it == encoders_.end()) {
+    it = encoders_.try_emplace(s.index, *this).first;
   }
   return it->second;
 }
 
-CommandEncoder::CommandEncoder(DeviceStream& s)
-    : device_(s.device()), stream_(s) {}
+CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
+  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+}
+
+void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
+  for (auto& [_, graph_exec] : graphs) {
+    CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
+  }
+  graphs.clear();
+}
+
+CommandEncoder::~CommandEncoder() {
+  clear_graphs(graph_cache_);
+}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
 }
 
-void CommandEncoder::end_encoding() {
-  if (!temporaries_.empty()) {
-    add_completed_handler([temporaries = std::move(temporaries_)]() {});
-  }
+void CommandEncoder::set_input_array(const array& arr) {
+  auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
+  active_deps_.push_back(id);
+}
 
-  // There is no kernel running, run completion handlers immediately.
-  if (!has_gpu_work_) {
-    worker_.consume_in_this_thread();
-    return;
-  }
-  has_gpu_work_ = false;
+void CommandEncoder::set_output_array(const array& arr) {
+  auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
+  active_deps_.push_back(id);
+  active_outputs_.push_back(id);
+}
 
-  // Put completion handlers in a batch.
-  worker_.end_batch();
-  // Signaling kernel completion is expensive, delay until enough batches.
-  // TODO: This number is arbitrarily picked, profile for a better stragety.
-  if (worker_.uncommited_batches() > 8) {
+void CommandEncoder::maybe_commit() {
+  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
     commit();
   }
 }
 
+void CommandEncoder::add_kernel_node(
+    void* func,
+    dim3 grid_dim,
+    dim3 block_dim,
+    void** params) {
+  cudaKernelNodeParams kernel_params = {0};
+  kernel_params.func = func;
+  kernel_params.gridDim = grid_dim;
+  kernel_params.blockDim = block_dim;
+  kernel_params.kernelParams = params;
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(
+      cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_kernel_node(
+    CUfunction func,
+    dim3 grid_dim,
+    dim3 block_dim,
+    void** params) {
+  CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
+  kernel_params.func = func;
+  kernel_params.gridDimX = grid_dim.x;
+  kernel_params.gridDimY = grid_dim.y;
+  kernel_params.gridDimZ = grid_dim.z;
+  kernel_params.blockDimX = block_dim.x;
+  kernel_params.blockDimY = block_dim.y;
+  kernel_params.blockDimZ = block_dim.z;
+  kernel_params.kernelParams = params;
+  CUgraphNode node;
+  CHECK_CUDA_ERROR(
+      cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
 void CommandEncoder::commit() {
-  worker_.commit(stream_.last_cuda_stream());
+  if (!temporaries_.empty()) {
+    add_completed_handler([temporaries = std::move(temporaries_)]() {});
+  }
+  if (node_count_ > 0) {
+    if (!from_nodes_.empty()) {
+      CHECK_CUDA_ERROR(cudaGraphAddDependencies(
+          graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
+    }
+    // TODO smarter cache policy
+    if (graph_cache_.size() > max_graph_cache_size) {
+      clear_graphs(graph_cache_);
+    }
+    graph_key_ += ".";
+    graph_key_ += std::to_string(node_count_);
+    graph_key_ += ".";
+    graph_key_ += std::to_string(graph_node_count_);
+    graph_key_ += ".";
+    graph_key_ += std::to_string(empty_node_count_);
+    auto [it, _] = graph_cache_.emplace(graph_key_, nullptr);
+    auto& graph_exec = it->second;
+    if (graph_exec != NULL) {
+      cudaGraphExecUpdateResultInfo update_result;
+      cudaGraphExecUpdate(graph_exec, graph_, &update_result);
+      if (update_result.result != cudaGraphExecUpdateSuccess) {
+        cudaGetLastError();
+        CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
+        graph_exec = NULL;
+      }
+    }
+    if (graph_exec == NULL) {
+      CHECK_CUDA_ERROR(
+          cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
+    }
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+
+    // Reset state
+    node_count_ = 0;
+    graph_node_count_ = 0;
+    from_nodes_.clear();
+    to_nodes_.clear();
+    graph_key_.clear();
+    node_map_.clear();
+    CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
+    CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  }
+
+  // Put completion handlers in a batch.
+  worker_.end_batch();
+  worker_.commit(stream_);
 }
 
 void CommandEncoder::synchronize() {
-  stream().synchronize();
+  cudaStreamSynchronize(stream_);
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
@@ -127,12 +316,8 @@ Device& device(mlx::core::Device device) {
   return it->second;
 }
 
-DeviceStream& get_stream(Stream s) {
-  return device(s.device).get_stream(s);
-}
-
 CommandEncoder& get_command_encoder(Stream s) {
-  return get_stream(s).get_encoder();
+  return device(s.device).get_command_encoder(s);
 }
 
 } // namespace cu
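
For reference, a condensed sketch (not MLX code) of the graph-reuse idiom that `commit()` above applies to its cache entries: try to patch a previously instantiated executable graph in place with `cudaGraphExecUpdate`, and only re-instantiate when the update fails. This assumes a CUDA 12-style `cudaGraphExecUpdate` signature, matching the one used in the diff.

```cuda
#include <cuda_runtime.h>

// `graph` was just rebuilt; `exec` may hold a cached executable graph from a
// previous, topologically identical build (or nullptr on the first use).
void launch_cached(cudaGraph_t graph, cudaGraphExec_t& exec, cudaStream_t stream) {
  if (exec != nullptr) {
    cudaGraphExecUpdateResultInfo info;
    cudaGraphExecUpdate(exec, graph, &info);
    if (info.result != cudaGraphExecUpdateSuccess) {
      cudaGetLastError(); // clear the error left by the failed update
      cudaGraphExecDestroy(exec);
      exec = nullptr;
    }
  }
  if (exec == nullptr) {
    // Fall back to a full instantiation when no usable cached exec exists.
    cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
  }
  cudaGraphLaunch(exec, stream);
}
```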


@@ -7,41 +7,108 @@
 #include "mlx/stream.h"
 
 #include <cublasLt.h>
+#include <cuda.h>
 #include <thrust/execution_policy.h>
 
 #include <unordered_map>
 
 namespace mlx::core::cu {
 
 class Device;
-class CommandEncoder;
 
-class DeviceStream {
+class CommandEncoder {
  public:
-  explicit DeviceStream(Device& device);
+  struct CaptureContext {
+    CaptureContext(CommandEncoder& enc);
+    ~CaptureContext();
+    cudaGraph_t graph;
+    CommandEncoder& enc;
+  };
+  struct ConcurrentContext {
+    ConcurrentContext(CommandEncoder& enc);
+    ~ConcurrentContext();
+    CommandEncoder& enc;
+  };
 
-  DeviceStream(const DeviceStream&) = delete;
-  DeviceStream& operator=(const DeviceStream&) = delete;
+  explicit CommandEncoder(Device& d);
+  ~CommandEncoder();
 
-  // Wait until kernels in the stream complete.
-  void synchronize();
+  CommandEncoder(const CommandEncoder&) = delete;
+  CommandEncoder& operator=(const CommandEncoder&) = delete;
 
-  // Return a cuda stream for launching kernels.
-  cudaStream_t schedule_cuda_stream();
-
-  // Return the last cuda stream used.
-  cudaStream_t last_cuda_stream();
-
-  CommandEncoder& get_encoder();
-
-  Device& device() {
-    return device_;
+  CaptureContext capture_context() {
+    return CaptureContext{*this};
+  }
+  ConcurrentContext concurrent_context() {
+    return ConcurrentContext{*this};
   }
 
+  void set_input_array(const array& arr);
+  void set_output_array(const array& arr);
+
+  template <typename F, typename... Params>
+  void
+  add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
+    constexpr size_t num = sizeof...(Params);
+    void* ptrs[num];
+    size_t i = 0;
+    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
+         std::forward<Params>(params)),
+     ...);
+    add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
+  }
+
+  void add_kernel_node(
+      CUfunction func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      void** params);
+
+  void
+  add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
+
+  void add_temporary(const array& arr) {
+    temporaries_.push_back(arr.data_shared_ptr());
+  }
+
+  void add_completed_handler(std::function<void()> task);
+  void maybe_commit();
+  void commit();
+
+  CudaStream& stream() {
+    return stream_;
+  }
+
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
  private:
-  Device& device_;
+  struct GraphNode {
+    cudaGraphNode_t node;
+    // K = kernel
+    // E = empty
+    // G = subgraph
+    char node_type;
+    std::string id;
+  };
+
+  void insert_graph_dependencies(GraphNode node);
+  void insert_graph_dependencies(std::vector<GraphNode> nodes);
+
   CudaStream stream_;
-  std::unique_ptr<CommandEncoder> encoder_;
+  cudaGraph_t graph_;
+  Worker worker_;
+  char node_count_{0};
+  char graph_node_count_{0};
+  char empty_node_count_{0};
+  bool in_concurrent_{false};
+  std::vector<cudaGraphNode_t> from_nodes_;
+  std::vector<cudaGraphNode_t> to_nodes_;
+  std::string graph_key_;
+  std::vector<GraphNode> concurrent_nodes_;
+  std::vector<std::shared_ptr<array::Data>> temporaries_;
+  std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
+  std::vector<std::uintptr_t> active_deps_;
+  std::vector<std::uintptr_t> active_outputs_;
+  std::unordered_map<std::uintptr_t, GraphNode> node_map_;
 };
 
 class Device {
@@ -55,7 +122,7 @@ class Device {
   // Make this device the current cuda device, required by some cuda calls.
   void make_current();
 
-  DeviceStream& get_stream(Stream s);
+  CommandEncoder& get_command_encoder(Stream s);
 
   int cuda_device() const {
     return device_;
@@ -75,67 +142,10 @@
   int compute_capability_major_;
   int compute_capability_minor_;
   cublasLtHandle_t lt_;
-  std::unordered_map<int, DeviceStream> streams_;
+  std::unordered_map<int, CommandEncoder> encoders_;
 };
 
-class CommandEncoder {
- public:
-  explicit CommandEncoder(DeviceStream& stream);
-
-  CommandEncoder(const CommandEncoder&) = delete;
-  CommandEncoder& operator=(const CommandEncoder&) = delete;
-
-  void set_input_array(const array& arr) {}
-  void set_output_array(const array& arr) {}
-  void add_temporary(const array& arr) {
-    temporaries_.push_back(arr.data_shared_ptr());
-  }
-
-  void add_completed_handler(std::function<void()> task);
-  void end_encoding();
-  void commit();
-
-  // Schedule a cuda stream for |fun| to launch kernels, and check error
-  // afterwards.
-  template <typename F>
-  void launch_kernel(F&& fun) {
-    launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
-  }
-
-  template <typename F>
-  void launch_kernel(cudaStream_t stream, F&& fun) {
-    device_.make_current();
-    fun(stream);
-    check_cuda_error("kernel launch", cudaGetLastError());
-    has_gpu_work_ = true;
-  }
-
-  Device& device() {
-    return device_;
-  }
-
-  DeviceStream& stream() {
-    return stream_;
-  }
-
-  bool has_gpu_work() const {
-    return has_gpu_work_;
-  }
-
-  // Wait until kernels and completion handlers are finished
-  void synchronize();
-
- private:
-  Device& device_;
-  DeviceStream& stream_;
-  Worker worker_;
-  bool has_gpu_work_{false};
-  std::vector<std::shared_ptr<array::Data>> temporaries_;
-};
-
 Device& device(mlx::core::Device device);
-DeviceStream& get_stream(Stream s);
 CommandEncoder& get_command_encoder(Stream s);
 
 // Return an execution policy that does not sync for result.
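
A standalone illustration (not MLX code) of the fold-expression packing used by the templated `add_kernel_node` above: each argument's address is written, in order, into a `void*` array, which is the layout `cudaGraphAddKernelNode` expects in `cudaKernelNodeParams::kernelParams`. The pointers only need to stay valid for the duration of the call, because the CUDA graph API copies the argument values immediately. `pack_and_print` is a hypothetical helper for demonstration only.

```cpp
#include <cstdio>
#include <utility>

template <typename... Params>
void pack_and_print(Params&&... params) {
  constexpr size_t num = sizeof...(Params);
  void* ptrs[num];
  size_t i = 0;
  // Comma-fold over the pack: invoke a lambda per argument that records
  // the address of that argument in the ptrs array.
  ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
       std::forward<Params>(params)),
   ...);
  for (size_t j = 0; j < num; ++j) {
    std::printf("arg %zu at %p\n", j, ptrs[j]);
  }
}

int main() {
  int n = 3;
  float s = 2.0f;
  pack_and_print(n, s, 7);
  return 0;
}
```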


@@ -37,22 +37,20 @@ void eval(array& arr) {
   }
 
   auto& encoder = cu::get_command_encoder(arr.primitive().stream());
-  if (encoder.has_gpu_work()) {
-    // Keep used buffers alive until kernel finishes running.
-    std::unordered_set<std::shared_ptr<array::Data>> buffers;
-    for (auto& in : arr.inputs()) {
-      buffers.insert(in.data_shared_ptr());
-    }
-    for (auto& s : arr.siblings()) {
-      buffers.insert(s.data_shared_ptr());
-    }
-    // Remove the output if it was donated to by an input.
-    if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
-      buffers.erase(it);
-    }
-    encoder.add_completed_handler([buffers = std::move(buffers)]() {});
+  // Keep used buffers alive until kernel finishes running.
+  std::unordered_set<std::shared_ptr<array::Data>> buffers;
+  for (auto& in : arr.inputs()) {
+    buffers.insert(in.data_shared_ptr());
   }
-  encoder.end_encoding();
+  for (auto& s : arr.siblings()) {
+    buffers.insert(s.data_shared_ptr());
+  }
+  // Remove the output if it was donated to by an input.
+  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
+    buffers.erase(it);
+  }
+  encoder.add_completed_handler([buffers = std::move(buffers)]() {});
+  encoder.maybe_commit();
 }
 
 void finalize(Stream s) {


@@ -61,7 +61,9 @@ void CudaEvent::wait(Stream s) {
   if (s.device == mlx::core::Device::cpu) {
     scheduler::enqueue(s, [*this]() mutable { wait(); });
   } else {
-    wait(cu::get_stream(s).last_cuda_stream());
+    auto& enc = cu::get_command_encoder(s);
+    enc.commit();
+    wait(enc.stream());
   }
 }
@@ -74,7 +76,9 @@ void CudaEvent::record(Stream s) {
   if (s.device == mlx::core::Device::cpu) {
     throw std::runtime_error("CudaEvent can not wait on cpu stream.");
   } else {
-    record(cu::get_stream(s).last_cuda_stream());
+    auto& enc = cu::get_command_encoder(s);
+    enc.commit();
+    record(enc.stream());
   }
 }
@@ -136,11 +140,9 @@ void SharedEvent::wait(Stream s, uint64_t value) {
     scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
   } else {
     auto& encoder = get_command_encoder(s);
-    encoder.launch_kernel(
-        encoder.stream().last_cuda_stream(),
-        [this, value](cudaStream_t stream) { wait(stream, value); });
+    encoder.commit();
+    wait(encoder.stream(), value);
     encoder.add_completed_handler([ac = ac_]() {});
-    encoder.end_encoding();
   }
 }
@@ -162,11 +164,9 @@ void SharedEvent::signal(Stream s, uint64_t value) {
     scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
   } else {
     auto& encoder = get_command_encoder(s);
-    encoder.launch_kernel(
-        encoder.stream().last_cuda_stream(),
-        [this, value](cudaStream_t stream) { signal(stream, value); });
+    encoder.commit();
+    signal(encoder.stream(), value);
     encoder.add_completed_handler([ac = ac_]() {});
-    encoder.end_encoding();
   }
 }


@@ -3,13 +3,16 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
 
 #include "cuda_jit_sources.h"
 
+#include <cuda.h>
 #include <fmt/format.h>
+#include <nvrtc.h>
 #include <nvtx3/nvtx3.hpp>
 
 #include <cassert>
@@ -22,7 +25,7 @@ namespace {
 constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
 
 void append_indices_arg(
-    cu::JitModule& mod,
+    cu::KernelArgs& args,
     const std::vector<array>& inputs,
     int nidx,
     int idx_ndim) {
@@ -30,7 +33,7 @@ void append_indices_arg(
   for (int i = 0; i < nidx; ++i) {
     indices[i] = inputs[i + 1].data<void>();
   }
-  mod.append_arg(std::move(indices));
+  args.append(std::move(indices));
   std::vector<int32_t> indices_shape(nidx * idx_ndim);
   for (int i = 0; i < nidx; ++i) {
     std::copy_n(
@@ -38,7 +41,7 @@ void append_indices_arg(
         idx_ndim,
         indices_shape.data() + i * idx_ndim);
   }
-  mod.append_arg(std::move(indices_shape));
+  args.append(std::move(indices_shape));
   std::vector<int64_t> indices_strides(nidx * idx_ndim);
   for (int i = 0; i < nidx; ++i) {
     std::copy_n(
@@ -46,7 +49,7 @@ void append_indices_arg(
         idx_ndim,
         indices_strides.data() + i * idx_ndim);
   }
-  mod.append_arg(std::move(indices_strides));
+  args.append(std::move(indices_strides));
 }
 
 } // namespace
@@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     return std::make_pair(jit_source_gather, std::move(kernel_names));
   });
 
-  mod.append_arg(src);
-  mod.append_arg(out);
+  cu::KernelArgs args;
+  args.append(src);
+  args.append(out);
   if (large) {
-    mod.append_arg<int64_t>(out.size());
+    args.append<int64_t>(out.size());
   } else {
-    mod.append_arg<int32_t>(out.size());
+    args.append<int32_t>(out.size());
   }
-  mod.append_ndim_arg(src.shape());
-  mod.append_ndim_arg(src.strides());
-  mod.append_arg<int32_t>(src.ndim());
-  mod.append_ndim_arg(slice_sizes_);
-  mod.append_arg(slice_size);
-  mod.append_arg(axes_);
-  append_indices_arg(mod, inputs, nidx, idx_ndim);
+  args.append_ndim(src.shape());
+  args.append_ndim(src.strides());
+  args.append<int32_t>(src.ndim());
+  args.append_ndim(slice_sizes_);
+  args.append(slice_size);
+  args.append(axes_);
+  append_indices_arg(args, inputs, nidx, idx_ndim);
 
   std::string kernel_name = fmt::format(
       "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
@@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     encoder.set_input_array(in);
   }
   encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    mod.launch_kernel(stream, kernel_name, out, large);
-  });
+
+  auto kernel = mod.get_kernel(kernel_name);
+  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     return std::make_pair(jit_source_scatter, std::move(kernel_names));
   });
 
-  mod.append_arg(upd);
-  mod.append_arg(out);
+  cu::KernelArgs args;
+  args.append(upd);
+  args.append(out);
   if (large) {
-    mod.append_arg<int64_t>(upd.size());
+    args.append<int64_t>(upd.size());
   } else {
-    mod.append_arg<int32_t>(upd.size());
+    args.append<int32_t>(upd.size());
   }
-  mod.append_ndim_arg(upd.shape());
-  mod.append_ndim_arg(upd.strides());
-  mod.append_arg<int32_t>(upd.ndim());
+  args.append_ndim(upd.shape());
+  args.append_ndim(upd.strides());
+  args.append<int32_t>(upd.ndim());
   if (large) {
-    mod.append_arg<int64_t>(upd_post_idx_size);
+    args.append<int64_t>(upd_post_idx_size);
   } else {
-    mod.append_arg<int32_t>(upd_post_idx_size);
+    args.append<int32_t>(upd_post_idx_size);
   }
-  mod.append_ndim_arg(out.shape());
-  mod.append_ndim_arg(out.strides());
-  mod.append_arg<int32_t>(out.ndim());
-  mod.append_arg(axes_);
-  append_indices_arg(mod, inputs, nidx, idx_ndim);
+  args.append_ndim(out.shape());
+  args.append_ndim(out.strides());
+  args.append<int32_t>(out.ndim());
+  args.append(axes_);
+  append_indices_arg(args, inputs, nidx, idx_ndim);
 
   std::string kernel_name = fmt::format(
       "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
@@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     encoder.set_input_array(in);
   }
   encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    mod.launch_kernel(stream, kernel_name, upd, large);
-  });
+  auto kernel = mod.get_kernel(kernel_name);
+  auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
 void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   size_t idx_size_axis = idx.shape(axis_);
-  mod.append_arg(src);
-  mod.append_arg(idx);
-  mod.append_arg(out);
+  cu::KernelArgs args;
+  args.append(src);
+  args.append(idx);
+  args.append(out);
   if (large) {
-    mod.append_arg<int64_t>(idx_size_pre);
-    mod.append_arg<int64_t>(idx_size_axis);
-    mod.append_arg<int64_t>(idx_size_post);
+    args.append<int64_t>(idx_size_pre);
+    args.append<int64_t>(idx_size_axis);
+    args.append<int64_t>(idx_size_post);
   } else {
-    mod.append_arg<int32_t>(idx_size_pre);
-    mod.append_arg<int32_t>(idx_size_axis);
-    mod.append_arg<int32_t>(idx_size_post);
+    args.append<int32_t>(idx_size_pre);
+    args.append<int32_t>(idx_size_axis);
+    args.append<int32_t>(idx_size_post);
   }
-  mod.append_arg(remove_index(idx.shape(), axis_));
-  mod.append_arg(remove_index(src.strides(), axis_));
-  mod.append_arg(remove_index(idx.strides(), axis_));
-  mod.append_arg<int32_t>(axis_);
-  mod.append_arg(src.shape(axis_));
-  mod.append_arg(src.strides(axis_));
-  mod.append_arg(idx.strides(axis_));
+  args.append(remove_index(idx.shape(), axis_));
+  args.append(remove_index(src.strides(), axis_));
+  args.append(remove_index(idx.strides(), axis_));
+  args.append<int32_t>(axis_);
+  args.append(src.shape(axis_));
+  args.append(src.strides(axis_));
+  args.append(idx.strides(axis_));
 
   std::string kernel_name = fmt::format(
       "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
@@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
     encoder.set_input_array(in);
   }
   encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    mod.launch_kernel(stream, kernel_name, idx, large);
-  });
+  auto kernel = mod.get_kernel(kernel_name);
+  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
 void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   size_t idx_size_axis = idx.shape(axis_);
-  mod.append_arg(upd);
-  mod.append_arg(idx);
-  mod.append_arg(out);
+  cu::KernelArgs args;
+  args.append(upd);
+  args.append(idx);
+  args.append(out);
   if (large) {
-    mod.append_arg<int64_t>(idx_size_pre);
-    mod.append_arg<int64_t>(idx_size_axis);
-    mod.append_arg<int64_t>(idx_size_post);
+    args.append<int64_t>(idx_size_pre);
+    args.append<int64_t>(idx_size_axis);
+    args.append<int64_t>(idx_size_post);
   } else {
-    mod.append_arg<int32_t>(idx_size_pre);
-    mod.append_arg<int32_t>(idx_size_axis);
-    mod.append_arg<int32_t>(idx_size_post);
+    args.append<int32_t>(idx_size_pre);
+    args.append<int32_t>(idx_size_axis);
+    args.append<int32_t>(idx_size_post);
   }
-  mod.append_arg(remove_index(idx.shape(), axis_));
-  mod.append_arg(remove_index(upd.strides(), axis_));
-  mod.append_arg(remove_index(idx.strides(), axis_));
-  mod.append_arg<int32_t>(axis_);
-  mod.append_arg(out.shape(axis_));
-  mod.append_arg(upd.strides(axis_));
-  mod.append_arg(idx.strides(axis_));
+  args.append(remove_index(idx.shape(), axis_));
+  args.append(remove_index(upd.strides(), axis_));
+  args.append(remove_index(idx.strides(), axis_));
+  args.append<int32_t>(axis_);
+  args.append(out.shape(axis_));
+  args.append(upd.strides(axis_));
+  args.append(idx.strides(axis_));
 
   std::string kernel_name = fmt::format(
       "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { auto kernel = mod.get_kernel(kernel_name);
mod.launch_kernel(stream, kernel_name, idx, large); auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
}); encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
} }
} // namespace mlx::core } // namespace mlx::core
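
The Gather/ScatterAxis call sites above no longer launch eagerly inside an encoder lambda; they pack the arguments with cu::KernelArgs, look the CUfunction up in the JIT module, compute the launch geometry, and hand everything to encoder.add_kernel_node so the launch is recorded as a CUDA graph node (per this commit's switch to CUDA graphs). The following standalone sketch illustrates that recording pattern with the runtime graph API and a plain __global__ kernel; it is not MLX's encoder, and the helper names in the diff (KernelArgs, get_launch_args, add_kernel_node) are internal APIs whose exact behavior is only assumed here.

// Illustrative sketch (not part of this commit): record one kernel launch as a
// CUDA graph node and replay the instantiated graph. Runtime API, CUDA 11.4+.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}

int main() {
  const int n = 1 << 20;
  float* d_data;
  cudaMalloc(&d_data, n * sizeof(float));

  // The node takes its arguments as an array of pointers to each parameter
  // value, which is exactly what a KernelArgs-style packer has to produce.
  float factor = 2.0f;
  int size = n;
  void* params[] = {&d_data, &factor, &size};

  cudaKernelNodeParams p = {};
  p.func = reinterpret_cast<void*>(scale);
  p.gridDim = dim3((n + 255) / 256);
  p.blockDim = dim3(256);
  p.sharedMemBytes = 0;
  p.kernelParams = params;
  p.extra = nullptr;

  cudaGraph_t graph;
  cudaGraphNode_t node;
  cudaGraphCreate(&graph, 0);
  cudaGraphAddKernelNode(&node, graph, nullptr, 0, &p);

  cudaGraphExec_t exec;
  cudaGraphInstantiateWithFlags(&exec, graph, 0);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaGraphLaunch(exec, stream); // replay the recorded launch
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(d_data);
  printf("done\n");
  return 0;
}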
@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) {
} }
} }
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
void check_cu_error(const char* name, CUresult err) {
if (err != CUDA_SUCCESS) {
const char* err_str = "Unknown error";
cuGetErrorString(err, &err_str);
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
}
}
// Return the location of the CUDA toolkit. // Return the location of the CUDA toolkit.
const std::string& cuda_home() { const std::string& cuda_home() {
static std::string home = []() -> std::string { static std::string home = []() -> std::string {
@ -280,60 +270,13 @@ JitModule::JitModule(
// Load kernels. // Load kernels.
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel; CUfunction kernel;
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel; kernels_[name] = kernel;
} }
} }
JitModule::~JitModule() { JitModule::~JitModule() {
CHECK_CU_ERROR(cuModuleUnload(module_)); CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
void JitModule::launch_kernel(
CUstream stream,
const std::string& kernel_name,
const array& arr,
bool large,
int work_per_thread) {
CUfunction kernel = get_kernel(kernel_name);
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
int _, block_dim;
CHECK_CU_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
if (block_dim > nthreads) {
block_dim = nthreads;
}
Dims num_blocks{1, 1, 1};
if (large) {
num_blocks =
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
std::get<0>(num_blocks) =
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
} else {
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
}
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
}
void JitModule::launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims) {
CHECK_CU_ERROR(cuLaunchKernel(
kernel,
std::get<0>(num_blocks),
std::get<1>(num_blocks),
std::get<2>(num_blocks),
std::get<0>(block_dims),
std::get<1>(block_dims),
std::get<2>(block_dims),
0,
stream,
args_.data(),
nullptr));
args_.clear();
storage_.clear();
} }
CUfunction JitModule::get_kernel(const std::string& kernel_name) { CUfunction JitModule::get_kernel(const std::string& kernel_name) {
@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second; return it->second;
} }
void JitModule::append_ptr_arg(const void* v) {
args_.push_back(const_cast<void*>(v));
}
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
@ -4,6 +4,7 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/config.h"
#include <deque> #include <deque>
@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair<
/* kernel names */ std::vector<std::string>>; /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>; using KernelBuilder = std::function<KernelBuilderResult()>;
class JitModule { struct KernelArgs {
public: void** args() {
JitModule( return args_.data();
Device& device, }
const std::string& module_name,
const KernelBuilder& builder);
~JitModule();
JitModule(const JitModule&) = delete; void append(const array& a) {
JitModule& operator=(const JitModule&) = delete; append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
void append_arg(const array& a) {
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
} }
template <typename T> template <typename T>
void append_arg(T val) { void append(T val) {
storage_.emplace_back(val); storage_.emplace_back(val);
append_ptr_arg(&storage_.back()); append_ptr(&storage_.back());
} }
template <typename T> template <typename T>
void append_arg(std::vector<T> vec) { void append(std::vector<T> vec) {
if (vec.empty()) { if (vec.empty()) {
// The nullptr can not be used as arg, pass something not null. // The nullptr can not be used as arg, pass something not null.
append_arg(std::monostate{}); append(std::monostate{});
} else { } else {
append_ptr_arg(vec.data()); append_ptr(vec.data());
storage_.emplace_back(std::move(vec)); storage_.emplace_back(std::move(vec));
} }
} }
// Make sure the arg is copied to an array with size of NDIM. // Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T> template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim_arg(const std::vector<T>& vec) { void append_ndim(std::vector<T> vec) {
if (vec.size() > NDIM) { if (vec.size() > NDIM) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM)); fmt::format("ndim can not be larger than {}.", NDIM));
} }
std::vector<T> copied(NDIM); vec.resize(NDIM);
std::copy(vec.begin(), vec.end(), copied.data()); append(std::move(vec));
append_arg(std::move(copied));
} }
// Launch kernel with |kernel_name| that each thread works on void append_ptr(const void* v) {
// |work_per_thread| elements of |arr|. args_.push_back(const_cast<void*>(v));
void launch_kernel( }
CUstream stream,
const std::string& kernel_name,
const array& arr,
bool large,
int work_per_thread = 1);
void launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims);
CUfunction get_kernel(const std::string& kernel_name);
private: private:
void append_ptr_arg(const void* v);
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
std::vector<void*> args_; std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store // The cuLaunchKernel API requires passing pointers to arguments so store
@ -105,6 +82,23 @@ class JitModule {
std::deque<Arg> storage_; std::deque<Arg> storage_;
}; };
class JitModule {
public:
JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
};
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
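
A note on the KernelArgs helper introduced in this header: both cuLaunchKernel and the graph kernel-node APIs take kernel arguments as an array of void*, where each entry points at the actual parameter value, so the packer has to keep every appended value alive at a stable address at least until the node is recorded. Storing the values in a std::deque works because, unlike std::vector, emplace_back on a deque never relocates existing elements. The fragment below is a stripped-down sketch of the same idea; it is not the MLX class, and the type list in the variant is illustrative only.

// Illustrative sketch (not part of this commit) of the KernelArgs idea: values
// live in a deque so their addresses stay stable, and args_ holds the void*
// pointers that cuLaunchKernel / kernel-node params expect.
#include <cstdint>
#include <deque>
#include <variant>
#include <vector>

struct ArgPack {
  using Slot = std::variant<std::monostate, int32_t, int64_t, float, void*>;

  template <typename T>
  void append(T val) {
    storage_.emplace_back(val);                     // deque: no relocation
    args_.push_back(&std::get<T>(storage_.back())); // pointer stays valid
  }

  void** data() {
    return args_.data();
  }

 private:
  std::vector<void*> args_;
  std::deque<Slot> storage_;
};

Using std::vector for the storage instead would invalidate earlier pointers whenever it reallocates, which is the design reason for the deque alongside the pointer vector.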
@ -12,6 +12,7 @@
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h> #include <cuComplex.h>
#include <cuda.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <fmt/format.h> #include <fmt/format.h>
@ -120,7 +121,13 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
template <typename T> template <typename T>
inline uint max_occupancy_block_dim(T kernel) { inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim; int _, block_dim;
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim; return block_dim;
} }
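
max_occupancy_block_dim above now serves both kinds of kernel handles: JIT-compiled kernels are CUfunctions and need the driver call cuOccupancyMaxPotentialBlockSize, while ordinary __global__ function pointers go through the runtime template cudaOccupancyMaxPotentialBlockSize. A standalone version of the same if constexpr dispatch, minus MLX's error-checking macros, looks like this.

// Illustrative sketch (not part of this commit): pick a block size from the
// occupancy API that matches the kernel handle type.
#include <cuda.h>
#include <cuda_runtime.h>
#include <type_traits>

template <typename Kernel>
unsigned int occupancy_block_dim(Kernel kernel) {
  int min_grid = 0;
  int block_dim = 0;
  if constexpr (std::is_same_v<Kernel, CUfunction>) {
    // Driver API for module-loaded (JIT-compiled) kernels.
    cuOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, kernel, nullptr, 0, 0);
  } else {
    // Runtime API template for ordinary __global__ function pointers.
    cudaOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, kernel);
  }
  return static_cast<unsigned int>(block_dim);
}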
@ -258,23 +258,23 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { constexpr uint32_t N_READS = 4;
constexpr uint32_t N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; encoder.add_kernel_node(
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>; kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
x.data<DataType>(), block_dim(),
w.data<DataType>(), x.data<DataType>(),
b.data<DataType>(), w.data<DataType>(),
out.data<DataType>(), b.data<DataType>(),
eps_, out.data<DataType>(),
axis_size, eps_,
w_stride, axis_size,
b_stride); w_stride,
}); b_stride);
}); });
}); });
} }
@ -289,21 +289,25 @@ void LayerNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array // Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the // is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler. // same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> { auto check_input = [&s](const array& x, bool& copied) {
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
return {x, false}; copied = false;
return x;
} }
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s); copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true}; return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable(); bool donate_g = inputs[3].is_donatable();
auto [x, copied] = check_input(inputs[0]); bool copied;
auto x = check_input(inputs[0], copied);
donate_x |= copied; donate_x |= copied;
const array& w = inputs[1]; const array& w = inputs[1];
const array& b = inputs[2]; const array& b = inputs[2];
auto [g, g_copied] = check_input(inputs[3]); bool g_copied;
auto g = check_input(inputs[3], g_copied);
donate_g |= g_copied; donate_g |= g_copied;
array& gx = outputs[0]; array& gx = outputs[0];
array& gw = outputs[1]; array& gw = outputs[1];
@ -334,8 +338,10 @@ void LayerNormVJP::eval_gpu(
// gradient accumulators. // gradient accumulators.
array gw_temp = array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
bool g_in_gw = false;
if (has_w) { if (has_w) {
if (!g_in_gx && donate_g) { if (!g_in_gx && donate_g) {
g_in_gw = true;
gw_temp.copy_shared_buffer(g); gw_temp.copy_shared_buffer(g);
} else { } else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
@ -343,41 +349,47 @@ void LayerNormVJP::eval_gpu(
} }
} }
// Finish with the gradient for b in case we had a b. // The gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) { bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size);
if (has_gb) {
ReductionPlan plan( ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
} }
// Insert dependency if `g` was donated
if ((g_in_gx || g_in_gw) && has_gb) {
encoder.set_input_array(gb);
}
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(g); encoder.set_input_array(g);
encoder.set_output_array(gx); encoder.set_output_array(gx);
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { dispatch_bool(has_w, [&](auto has_w_constant) {
dispatch_bool(has_w, [&](auto has_w_constant) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(
dispatch_block_dim( cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; auto kernel = cu::layer_norm_vjp<
auto kernel = cu::layer_norm_vjp< DataType,
DataType, has_w_constant.value,
has_w_constant.value, block_dim(),
block_dim(), N_READS>;
N_READS>; encoder.add_kernel_node(
kernel<<<n_rows, block_dim(), 0, stream>>>( kernel,
x.data<DataType>(), n_rows,
w.data<DataType>(), block_dim(),
g.data<DataType>(), x.data<DataType>(),
gx.data<DataType>(), w.data<DataType>(),
gw_temp.data<DataType>(), g.data<DataType>(),
eps_, gx.data<DataType>(),
axis_size, gw_temp.data<DataType>(),
w_stride); eps_,
}); axis_size,
}); w_stride);
});
}); });
}); });
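
The new encoder.set_input_array(gb) call in the VJP above exists purely for ordering: when g is donated into gx or gw_temp, the VJP kernel overwrites the buffer that the earlier col_reduce producing gb still reads, so declaring gb as an input of the VJP node adds a graph edge from that reduction and serializes the hazard. The sketch below shows one way declared reads and writes can be turned into explicit CUDA graph dependencies; it is an assumption about the general technique, not a transcription of MLX's CommandEncoder.

// Illustrative sketch (not part of this commit): hypothetical helper that maps
// declared buffer reads/writes onto CUDA graph edges.
#include <cuda_runtime.h>
#include <unordered_map>
#include <vector>

class NodeDeps {
 public:
  // Declare a buffer the next node will read; remember who produced it.
  void read(const void* buf) {
    auto it = last_writer_.find(buf);
    if (it != last_writer_.end()) {
      deps_.push_back(it->second);
    }
  }

  // After recording `node`, wire edges from its producers and mark what it wrote.
  void wrote(cudaGraph_t graph, cudaGraphNode_t node, const void* buf) {
    for (auto dep : deps_) {
      cudaGraphAddDependencies(graph, &dep, &node, 1); // edge: dep -> node
    }
    deps_.clear();
    last_writer_[buf] = node;
  }

 private:
  std::unordered_map<const void*, cudaGraphNode_t> last_writer_;
  std::vector<cudaGraphNode_t> deps_;
};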
@ -143,16 +143,18 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; encoder.add_kernel_node(
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>; kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
in.data<DataType>(), out.data<DataType>(), axis_size); block_dim(),
}); in.data<DataType>(),
out.data<DataType>(),
axis_size);
}); });
}); });
} }
@ -42,7 +42,8 @@ class MatMul {
int64_t ldb, int64_t ldb,
int32_t batch_count, int32_t batch_count,
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride) { int64_t b_batch_stride)
: handle_(device.lt_handle()) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype); auto scale_type = dtype_to_cuda_type(dtype);
@ -147,7 +148,7 @@ class MatMul {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0; int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
encoder.device().lt_handle(), handle_,
matmul_desc_, matmul_desc_,
a_desc_, a_desc_,
b_desc_, b_desc_,
@ -172,25 +173,24 @@ class MatMul {
workspace_ptr = workspace.data<void>(); workspace_ptr = workspace.data<void>();
} }
encoder.launch_kernel([&](cudaStream_t stream) { auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
encoder.device().lt_handle(), handle_,
matmul_desc_, matmul_desc_,
&alpha, &alpha,
a, a,
a_desc_, a_desc_,
b, b,
b_desc_, b_desc_,
&beta, &beta,
c ? c : out, c ? c : out,
c ? c_desc_ : out_desc_, c ? c_desc_ : out_desc_,
out, out,
out_desc_, out_desc_,
&heuristic_.algo, &heuristic_.algo,
workspace_ptr, workspace_ptr,
heuristic_.workspaceSize, heuristic_.workspaceSize,
stream)); encoder.stream()));
});
} }
private: private:
@ -259,6 +259,7 @@ class MatMul {
return desc; return desc;
} }
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr};
@ -273,7 +274,7 @@ class MatMul {
namespace { namespace {
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) { check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && stx == arr.shape(-1)) { if (sty == 1 && stx == arr.shape(-1)) {
@ -283,7 +284,7 @@ check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s); copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy);
} }
} }
@ -317,13 +318,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release // Keep a vector with copies to be cleared in the completed buffer to release
// the arrays // the arrays
std::vector<array> copies; auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions // Check and collapse batch dimensions
@ -348,7 +344,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Invoke cublasLt // Invoke cublasLt
cu::MatMul matmul( cu::MatMul matmul(
encoder.device(), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
M, M,
@ -373,6 +369,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
@ -405,14 +402,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release // Keep a vector with copies to be cleared in the completed buffer to release
// the arrays // the arrays
std::vector<array> copies; auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions // Check and collapse batch dimensions
@ -440,7 +432,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Invoke cublasLt // Invoke cublasLt
cu::MatMul matmul( cu::MatMul matmul(
encoder.device(), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
M, M,
@ -478,6 +470,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
matmul.run( matmul.run(
encoder, encoder,
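
Two encoder facilities appear in the matmul path above: capture_context(), which wraps the cublasLtMatmul call because library launches cannot be added as explicit kernel nodes, and concurrent_context(), which lets the per-batch matmuls be recorded without serializing edges between them. The snippet below sketches the capture half with plain CUDA: work enqueued on a stream while it is being captured is recorded as a child graph rather than executed immediately. It illustrates the mechanism only, not MLX's implementation, and it captures a memset stand-in instead of cuBLAS to stay self-contained.

// Illustrative sketch (not part of this commit): capture stream-ordered work
// (e.g. a library call) into a graph as a child-graph node. CUDA 11+.
#include <cuda_runtime.h>

void add_captured_work(cudaGraph_t graph, cudaStream_t stream, void* buf, size_t bytes) {
  // Everything enqueued on `stream` between Begin/EndCapture is recorded,
  // not executed.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);

  // Stand-in for a library call such as cublasLtMatmul(..., stream).
  cudaMemsetAsync(buf, 0, bytes, stream);

  cudaGraph_t child = nullptr;
  cudaStreamEndCapture(stream, &child);

  // Splice the captured work into the main graph as a single node.
  cudaGraphNode_t node = nullptr;
  cudaGraphAddChildGraphNode(&node, graph, nullptr, 0, child);
  cudaGraphDestroy(child); // the parent graph keeps its own copy
}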
@ -24,23 +24,21 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
auto& s = stream(); auto& encoder = cu::get_command_encoder(stream());
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) { auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>; using OutType = cuda_type_t<CTYPE>;
CTYPE step = CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_); static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform( thrust::transform(
cu::thrust_policy(stream), cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0), thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()), thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()), thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{ cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)}); static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}); });
} }
@ -156,34 +156,39 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(keys); encoder.set_input_array(keys);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dim3 grid_dims{num_keys, half_size + odd};
dim3 grid_dims{num_keys, half_size + odd}; int64_t total = grid_dims.x * grid_dims.y;
int64_t total = grid_dims.x * grid_dims.y; int32_t threads_y = 1;
int32_t threads_y = 1; while ((total / threads_y) >= (1U << 31)) {
while ((total / threads_y) >= (1U << 31)) { threads_y *= 2;
threads_y *= 2; }
} int32_t threads_x = cuda::ceil_div(total, threads_y);
int32_t threads_x = cuda::ceil_div(total, threads_y); auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); auto& stream = encoder.stream();
if (keys.flags().row_contiguous) { if (keys.flags().row_contiguous) {
cu::rbitsc<<<grid, block, 0, stream>>>( encoder.add_kernel_node(
keys.data<uint32_t>(), cu::rbitsc,
out.data<uint8_t>(), grid,
grid_dims, block,
odd, keys.data<uint32_t>(),
bytes_per_key); out.data<uint8_t>(),
} else { grid_dims,
cu::rbits<<<grid, block, 0, stream>>>( odd,
keys.data<uint32_t>(), bytes_per_key);
out.data<uint8_t>(), } else {
grid_dims, encoder.add_kernel_node(
odd, cu::rbits,
bytes_per_key, grid,
keys.ndim(), block,
const_param(keys.shape()), keys.data<uint32_t>(),
const_param(keys.strides())); out.data<uint8_t>(),
} grid_dims,
}); odd,
bytes_per_key,
keys.ndim(),
const_param(keys.shape()),
const_param(keys.strides()));
}
} }
} // namespace mlx::core } // namespace mlx::core
@ -110,19 +110,20 @@ void all_reduce(
intermediate.set_data(allocator::malloc(intermediate.nbytes())); intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate); encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(dt, [&](auto type_tag) {
dispatch_all_types(dt, [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; encoder.add_kernel_node(
kernel<<<blocks, threads, 0, stream>>>( kernel,
static_cast<T*>(indata), blocks,
intermediate.data<U>(), threads,
block_step, static_cast<T*>(indata),
insize); intermediate.data<U>(),
}); block_step,
insize);
}); });
}); });
@ -135,16 +136,20 @@ void all_reduce(
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(dt, [&](auto type_tag) {
dispatch_all_types(dt, [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; encoder.add_kernel_node(
kernel<<<blocks, threads, 0, stream>>>( kernel,
static_cast<T*>(indata), out.data<U>(), block_step, insize); blocks,
}); threads,
static_cast<T*>(indata),
out.data<U>(),
block_step,
insize);
}); });
}); });
} }
@ -214,26 +214,24 @@ void col_reduce_looped(
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Cub doesn't like const pointers for vectorized loads. (sigh) constexpr int N_READS = 4;
T* indata = const_cast<T*>(in.data<T>()); constexpr int BM = 32;
constexpr int BN = 32;
constexpr int N_READS = 4; dim3 grid = output_grid_for_col_reduce(out, args, BN);
constexpr int BM = 32; int blocks = BM * BN / N_READS;
constexpr int BN = 32; auto kernel =
dim3 grid = output_grid_for_col_reduce(out, args, BN); cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
int blocks = BM * BN / N_READS; encoder.add_kernel_node(
auto kernel = kernel, grid, blocks, indata, out.data<U>(), args);
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
});
}); });
}); });
}); });
@ -32,18 +32,16 @@ void init_reduce(
} }
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; auto kernel = cu::init_reduce<T, U, OP>;
auto kernel = cu::init_reduce<T, U, OP>; dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); grid.x = (grid.x + 1023) / 1024;
grid.x = (grid.x + 1023) / 1024; encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
});
}); });
}); });
} }
@ -245,34 +245,32 @@ void row_reduce_simple(
// 2 passes. Something like 32 * out.size() and then do a warp reduce. // 2 passes. Something like 32 * out.size() and then do a warp reduce.
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh) // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>()); T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions); int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
// Pick the kernel // Pick the kernel
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>; auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
if (grid.x >= 1024) { if (grid.x >= 1024) {
grid.x = (grid.x + 1) / 2; grid.x = (grid.x + 1) / 2;
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>; kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
} }
// Launch int size = plan.shape.back();
kernel<<<grid, block, 0, stream>>>( encoder.add_kernel_node(
indata, out.data<U>(), out.size(), plan.shape.back()); kernel, grid, block, indata, out.data<U>(), out.size(), size);
});
}); });
}); });
} }
@ -293,43 +291,39 @@ void row_reduce_looped(
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { using OP = MLX_GET_TYPE(reduce_type_tag);
using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using U = typename cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type; // Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
// Cub doesn't like const pointers for vectorized loads. (sigh) // Calculate the grid and block dims
T* indata = const_cast<T*>(in.data<T>()); args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Calculate the grid and block dims // Pick the kernel
args.sort_access_pattern(in, axes); auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
size_t reductions = (args.row_size + N_READS - 1) / N_READS; dispatch_block_dim(threads, [&](auto threads_constant) {
int threads = std::min(1024UL, reductions); kernel = cu::row_reduce_looped<
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; T,
dim3 block(threads, 1, 1); U,
OP,
// Pick the kernel reduce_ndim.value,
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>; threads_constant.value,
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { N_READS>;
dispatch_block_dim(threads, [&](auto threads_constant) { block.x = threads_constant.value;
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
}); });
// Launch
kernel<<<grid, block, 0, stream>>>(
indata, out.data<U>(), out.size(), args);
}); });
encoder.add_kernel_node(
kernel, grid, block, indata, out.data<U>(), out.size(), args);
}); });
}); });
} }
@ -224,21 +224,21 @@ void RMSNorm::eval_gpu(
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { constexpr uint32_t N_READS = 4;
constexpr uint32_t N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; encoder.add_kernel_node(
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>; kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
x.data<DataType>(), block_dim(),
w.data<DataType>(), x.data<DataType>(),
out.data<DataType>(), w.data<DataType>(),
eps_, out.data<DataType>(),
axis_size, eps_,
w_stride); axis_size,
}); w_stride);
}); });
}); });
} }
@ -253,20 +253,24 @@ void RMSNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array // Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the // is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler. // same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> { auto check_input = [&s](const array& x, bool& copied) {
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
return {x, false}; copied = false;
return x;
} }
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s); copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true}; return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable(); bool donate_g = inputs[2].is_donatable();
auto [x, copied] = check_input(inputs[0]); bool copied;
auto x = check_input(inputs[0], copied);
donate_x |= copied; donate_x |= copied;
const array& w = inputs[1]; const array& w = inputs[1];
auto [g, g_copied] = check_input(inputs[2]); bool g_copied;
auto g = check_input(inputs[2], g_copied);
donate_g |= g_copied; donate_g |= g_copied;
array& gx = outputs[0]; array& gx = outputs[0];
array& gw = outputs[1]; array& gw = outputs[1];
@ -310,30 +314,31 @@ void RMSNormVJP::eval_gpu(
encoder.set_input_array(g); encoder.set_input_array(g);
encoder.set_output_array(gx); encoder.set_output_array(gx);
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { dispatch_bool(has_w, [&](auto has_w_constant) {
dispatch_bool(has_w, [&](auto has_w_constant) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(
dispatch_block_dim( cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; constexpr int N_READS = 4;
constexpr int N_READS = 4; auto kernel = cu::rms_norm_vjp<
auto kernel = cu::rms_norm_vjp< DataType,
DataType, has_w_constant.value,
has_w_constant.value, block_dim(),
block_dim(), N_READS>;
N_READS>; encoder.add_kernel_node(
kernel<<<n_rows, block_dim(), 0, stream>>>( kernel,
x.data<DataType>(), n_rows,
w.data<DataType>(), block_dim(),
g.data<DataType>(), x.data<DataType>(),
gx.data<DataType>(), w.data<DataType>(),
gw_temp.data<DataType>(), g.data<DataType>(),
eps_, gx.data<DataType>(),
axis_size, gw_temp.data<DataType>(),
w_stride); eps_,
}); axis_size,
}); w_stride);
});
}); });
}); });
@ -308,76 +308,89 @@ void RoPE::eval_gpu(
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in); encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset); encoder.set_input_array(offset);
if (with_freqs) {
encoder.set_input_array(inputs[2]);
}
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { dispatch_bool(traditional_, [&](auto traditional) {
dispatch_bool(traditional_, [&](auto traditional) { dispatch_bool(forward_, [&](auto forward) {
dispatch_bool(forward_, [&](auto forward) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; if (single && !with_freqs) {
if (single && !with_freqs) { auto kernel =
auto kernel = cu::rope_single<DataType, traditional.value, forward.value>;
cu::rope_single<DataType, traditional.value, forward.value>; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); encoder.add_kernel_node(
kernel<<<grid, block, 0, stream>>>( kernel,
(donated ? out : in).data<DataType>(), grid,
out.data<DataType>(), block,
offset.data<int32_t>(), (donated ? out : in).data<DataType>(),
scale_, out.data<DataType>(),
std::log2(base_), offset.data<int32_t>(),
mat_size, scale_,
dims); std::log2(base_),
} else if (single) { mat_size,
auto kernel = cu:: dims);
rope_single_freqs<DataType, traditional.value, forward.value>; } else if (single) {
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto kernel =
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); cu::rope_single_freqs<DataType, traditional.value, forward.value>;
kernel<<<grid, block, 0, stream>>>( uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
(donated ? out : in).data<DataType>(), auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
out.data<DataType>(), encoder.add_kernel_node(
offset.data<int32_t>(), kernel,
inputs[2].data<float>(), grid,
scale_, block,
mat_size, (donated ? out : in).data<DataType>(),
dims, out.data<DataType>(),
inputs[2].strides(0)); offset.data<int32_t>(),
} else if (with_freqs) { inputs[2].data<float>(),
auto kernel = scale_,
cu::rope_freqs<DataType, traditional.value, forward.value>; mat_size,
uint3 dims = dims,
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); inputs[2].strides(0));
dims.z = (dims.z + 3) / 4; } else if (with_freqs) {
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto kernel =
kernel<<<grid, block, 0, stream>>>( cu::rope_freqs<DataType, traditional.value, forward.value>;
(donated ? out : in).data<DataType>(), uint3 dims =
out.data<DataType>(), make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
offset.data<int32_t>(), dims.z = (dims.z + 3) / 4;
inputs[2].data<float>(), auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
scale_, encoder.add_kernel_node(
std::log2(base_), kernel,
strides, grid,
out_strides, block,
in.size() / mat_size, (donated ? out : in).data<DataType>(),
dims, out.data<DataType>(),
inputs[2].strides(0)); offset.data<int32_t>(),
} else { inputs[2].data<float>(),
auto kernel = cu::rope<DataType, traditional.value, forward.value>; scale_,
uint3 dims = std::log2(base_),
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); strides,
dims.z = (dims.z + 3) / 4; out_strides,
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); in.size() / mat_size,
kernel<<<grid, block, 0, stream>>>( dims,
(donated ? out : in).data<DataType>(), inputs[2].strides(0));
out.data<DataType>(), } else {
offset.data<int32_t>(), auto kernel = cu::rope<DataType, traditional.value, forward.value>;
scale_, uint3 dims =
std::log2(base_), make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
strides, dims.z = (dims.z + 3) / 4;
out_strides, auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
in.size() / mat_size, encoder.add_kernel_node(
dims); kernel,
} grid,
}); block,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims);
}
}); });
}); });
}); });
@ -141,19 +141,21 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { constexpr int N_READS = 4;
constexpr int N_READS = 4; dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dispatch_block_dim( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; if (precise) {
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>; kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
if (precise) { }
kernel = cu::softmax<DataType, float, block_dim(), N_READS>; encoder.add_kernel_node(
} kernel,
kernel<<<n_rows, block_dim(), 0, stream>>>( n_rows,
in.data<DataType>(), out.data<DataType>(), axis_size); block_dim(),
}); in.data<DataType>(),
out.data<DataType>(),
axis_size);
}); });
}); });
} }
@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) {
return out; return out;
} }
template <typename... Args>
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(), size, args...));
}
template <typename... Args>
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(), size, args...));
}
struct OffsetTransform { struct OffsetTransform {
int nsort; int nsort;
@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag);
using CTYPE = MLX_GET_TYPE(type_tag); auto& stream = encoder.stream();
if constexpr (!std::is_same_v<CTYPE, complex64_t>) { if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = cuda_type_t<CTYPE>; using Type = cuda_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator( auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), OffsetTransform{nsort}); thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) { if (argsort) {
// Indices in the sorted dimension. // Indices in the sorted dimension.
array indices( array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
allocator::malloc(out.nbytes()), in.shape(), out.dtype()); encoder.add_temporary(indices);
encoder.add_temporary(indices);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
// In argsort though we don't need the result of sorted values, the // In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it. // API requires us to provide an array to store it.
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
encoder.add_temporary(discard); encoder.add_temporary(discard);
segmented_sort_pairs( size_t size;
encoder, CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
in.data<Type>(), nullptr,
discard.data<Type>(), size,
indices.data<uint32_t>(), in.data<Type>(),
out.data<uint32_t>(), discard.data<Type>(),
in.data_size(), indices.data<uint32_t>(),
in.data_size() / nsort, out.data<uint32_t>(),
offsets, in.data_size(),
offsets + 1, in.data_size() / nsort,
stream); offsets,
} else { offsets + 1,
segmented_sort( stream));
encoder,
in.data<Type>(), array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
out.data<Type>(), encoder.add_temporary(temp);
in.data_size(),
in.data_size() / nsort, // Start capturing after allocations
offsets, auto capture = encoder.capture_context();
offsets + 1, thrust::transform(
stream); cu::thrust_policy(stream),
} thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(indices.data_size()),
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
temp.data<void>(),
size,
in.data<Type>(),
discard.data<Type>(),
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
} else { } else {
throw std::runtime_error( size_t size;
"CUDA backend does not support sorting complex numbers"); CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
nullptr,
size,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
temp.data<void>(),
size,
in.data<Type>(),
out.data<Type>(),
in.data_size(),
in.data_size() / nsort,
offsets,
offsets + 1,
stream));
} }
}); } else {
throw std::runtime_error(
"CUDA backend does not support sorting complex numbers");
}
}); });
if (!is_segmented_sort) { if (!is_segmented_sort) {
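
The restructured sort above queries CUB for its temporary-storage size and allocates indices, discard, and temp before encoder.capture_context() starts recording; the "Start capturing after allocations" comment reflects that allocation is not stream-ordered work and can synchronize, which is presumably why it is kept outside the capture. A hedged standalone version of that query-allocate-then-capture ordering with cub::DeviceSegmentedSort is shown below; the buffer names and fixed segment layout are illustrative.

// Illustrative sketch (not part of this commit): CUB two-phase pattern under
// stream capture. Size query and cudaMalloc happen first; only the real sort
// is enqueued while capturing.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

cudaGraph_t capture_segmented_sort(
    cudaStream_t stream,
    const float* keys_in,
    float* keys_out,
    const int* offsets, // num_segments + 1 entries, device memory
    int num_items,
    int num_segments) {
  // Phase 1 (not captured): ask CUB how much temporary storage it needs.
  size_t temp_bytes = 0;
  cub::DeviceSegmentedSort::StableSortKeys(
      nullptr, temp_bytes, keys_in, keys_out, num_items, num_segments,
      offsets, offsets + 1, stream);

  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes); // allocation stays outside the capture

  // Phase 2 (captured): the actual sort becomes part of the graph.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  cub::DeviceSegmentedSort::StableSortKeys(
      temp, temp_bytes, keys_in, keys_out, num_items, num_segments,
      offsets, offsets + 1, stream);
  cudaGraph_t graph = nullptr;
  cudaStreamEndCapture(stream, &graph);
  return graph; // caller instantiates/launches the graph and frees `temp` later
}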
@ -91,73 +91,80 @@ void ternary_op_gpu_inplace(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_input_array(c); encoder.set_input_array(c);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(out.dtype(), [&](auto type_tag) {
dispatch_all_types(out.dtype(), [&](auto type_tag) { using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::General) {
dispatch_bool( dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
[&](auto large) { [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape; Shape shape;
std::vector<Strides> strides; std::vector<Strides> strides;
std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
auto& a_strides = strides[0]; auto& a_strides = strides[0];
auto& b_strides = strides[1]; auto& b_strides = strides[1];
auto& c_strides = strides[2]; auto& c_strides = strides[2];
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>; cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides),
const_param<dims_constant()>(c_strides));
});
} else {
auto kernel = cu::ternary_g<Op, DType, IdxT>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large()); get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>( encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<bool>(), a.data<bool>(),
b.data<DType>(), b.data<DType>(),
c.data<DType>(), c.data<DType>(),
out.data<DType>(), out.data<DType>(),
out.data_size(), out.size(),
const_param(shape), const_param<dims_constant()>(shape),
const_param(a_strides), const_param<dims_constant()>(a_strides),
const_param(b_strides), const_param<dims_constant()>(b_strides),
const_param(c_strides), const_param<dims_constant()>(c_strides));
ndim); });
} } else {
}); auto kernel = cu::ternary_g<Op, DType, IdxT>;
} else { auto [num_blocks, block_dims] =
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { get_launch_args(kernel, out, large());
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; encoder.add_kernel_node(
auto kernel = cu::ternary_v<Op, DType, IdxT>; kernel,
auto [num_blocks, block_dims] = get_launch_args( num_blocks,
kernel, out.data_size(), out.shape(), out.strides(), large()); block_dims,
kernel<<<num_blocks, block_dims, 0, stream>>>( a.data<bool>(),
a.data<bool>(), b.data<DType>(),
b.data<DType>(), c.data<DType>(),
c.data<DType>(), out.data<DType>(),
out.data<DType>(), out.data_size(),
out.data_size()); const_param(shape),
}); const_param(a_strides),
} const_param(b_strides),
}); const_param(c_strides),
ndim);
}
});
} else {
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large());
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
}
}); });
} }
@ -9,14 +9,38 @@
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(in[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]);
}
}
template <typename Op, typename In, typename Out> template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() { constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> || if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
@ -71,38 +95,61 @@ void unary_op_gpu_inplace(
if (in.size() == 0) { if (in.size() == 0) {
return; return;
} }
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) { using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_IN = MLX_GET_TYPE(in_type_tag); using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) { dispatch_bool(large, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using InType = cuda_type_t<CTYPE_IN>; using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>; using OutType = cuda_type_t<CTYPE_OUT>;
auto policy = cu::thrust_policy(stream); using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto in_ptr = thrust::device_pointer_cast(in.data<InType>()); if (contig) {
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>()); auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
if (in.flags().contiguous) { auto [num_blocks, block_dims] = get_launch_args(
thrust::transform( kernel, out.data_size(), out.shape(), out.strides(), large);
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else { } else {
auto [shape, strides] = collapse_contiguous_dims(in); auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>( auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
in_ptr, in.size(), shape, strides); auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
thrust::transform(policy, in_begin, in_end, out_ptr, Op()); encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
in.data<InType>(),
out.data<OutType>(),
out.data_size(),
const_param(shape),
const_param(strides),
shape.size());
} }
} else { });
throw std::runtime_error(fmt::format( } else {
"Can not do unary op {} on input of {} with output of {}.", throw std::runtime_error(fmt::format(
op, "Can not do unary op {} on input of {} with output of {}.",
dtype_to_string(in.dtype()), op,
dtype_to_string(out.dtype()))); dtype_to_string(in.dtype()),
} dtype_to_string(out.dtype())));
}); }
}); });
}); });
} }
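
The unary path above drops thrust::transform in favor of the hand-written unary_v/unary_g kernels so each launch can be recorded directly with encoder.add_kernel_node (the commit message's "use node api directly to reduce overhead"). The alternative, still used for Arange earlier in this commit, is to run Thrust inside a capture context so its internal launches land in the graph as captured work. A minimal illustration of that captured-Thrust route follows; it is a generic sketch, not the MLX code path, and it assumes a Thrust version that provides the non-synchronizing par_nosync policy.

// Illustrative sketch (not part of this commit): record a thrust::transform
// into a CUDA graph via stream capture.
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/transform.h>

struct Square {
  __device__ float operator()(float x) const {
    return x * x;
  }
};

cudaGraph_t capture_square(cudaStream_t stream, float* data, int n) {
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  auto ptr = thrust::device_pointer_cast(data);
  // Thrust enqueues its kernels on `stream`; while capturing, they are
  // recorded into the graph instead of running immediately.
  thrust::transform(thrust::cuda::par_nosync.on(stream), ptr, ptr + n, ptr, Square{});
  cudaGraph_t graph = nullptr;
  cudaStreamEndCapture(stream, &graph);
  return graph;
}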
@ -24,6 +24,14 @@ void check_cuda_error(const char* name, cudaError_t err) {
} }
} }
void check_cuda_error(const char* name, CUresult err) {
if (err != CUDA_SUCCESS) {
const char* err_str = "Unknown error";
cuGetErrorString(err, &err_str);
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
}
}
const char* dtype_to_cuda_type(const Dtype& dtype) { const char* dtype_to_cuda_type(const Dtype& dtype) {
switch (dtype) { switch (dtype) {
case bool_: case bool_:
@ -4,6 +4,7 @@
#pragma once #pragma once
#include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
namespace mlx::core { namespace mlx::core {
@ -33,6 +34,7 @@ class CudaStream {
// Throw exception if the cuda API does not succeed. // Throw exception if the cuda API does not succeed.
void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed. // The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
@ -688,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
perm = expand_dims(perm, -1, s); perm = expand_dims(perm, -1, s);
take_axis -= 1; take_axis -= 1;
} }
auto pb = take_along_axis(b, perm, take_axis); auto pb = take_along_axis(b, perm, take_axis, s);
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s); auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
return solve_triangular(luf[2], y, /* upper = */ true, s); return solve_triangular(luf[2], y, /* upper = */ true, s);
} }
@ -391,9 +391,11 @@ class TestLoad(mlx_tests.MLXTestCase):
scale = mx.array(2.0) scale = mx.array(2.0)
y = mx.load(save_file) y = mx.load(save_file)
mx.eval(y) mx.eval(y)
mx.synchronize()
load_only = mx.get_peak_memory() load_only = mx.get_peak_memory()
y = mx.load(save_file) * scale y = mx.load(save_file) * scale
mx.eval(y) mx.eval(y)
mx.synchronize()
load_with_binary = mx.get_peak_memory() load_with_binary = mx.get_peak_memory()
self.assertEqual(load_only, load_with_binary) self.assertEqual(load_only, load_with_binary)