mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
10 Commits
v0.26.2
...
9794ec6b8e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9794ec6b8e | ||
|
|
e0bb9f3ef8 | ||
|
|
5b089dc5da | ||
|
|
af74818528 | ||
|
|
0d30e9e8ec | ||
|
|
0e0d9ac522 | ||
|
|
8917022deb | ||
|
|
ec0d5db67b | ||
|
|
e76e9b87f0 | ||
|
|
cfb6a244ea |
@@ -192,6 +192,17 @@ void time_reductions() {
|
||||
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
|
||||
auto indices = mlx::core::array({1});
|
||||
auto updates = mlx::core::reshape(mlx::core::array({NAN}), {1, 1, 1});
|
||||
std::vector<int> axes{0};
|
||||
auto b = scatter(a, {indices}, updates, axes);
|
||||
mx::eval(b);
|
||||
|
||||
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||
TIME(max_along_0);
|
||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||
TIME(max_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
|
||||
@@ -51,6 +51,13 @@ def time_maximum():
|
||||
time_fn(mx.maximum, a, b)
|
||||
|
||||
|
||||
def time_max():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@@ -108,6 +115,7 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_max()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
time_negative()
|
||||
|
||||
@@ -12,16 +12,11 @@ namespace mlx::core {
|
||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
// Get and check the shape for the batched dims
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
if (A_bshape != B_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
@@ -42,17 +37,11 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
// Get and check the shape for the batched dims
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}, {0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
@@ -151,30 +151,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
auto kernel =
|
||||
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = cu::
|
||||
arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dim(), 0, stream>>>(
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
auto kernel =
|
||||
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dim(),
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -139,90 +139,92 @@ void binary_op_gpu_inplace(
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::
|
||||
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -137,98 +137,101 @@ void binary_op_gpu_inplace(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out_a.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out_a.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::
|
||||
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -178,6 +179,7 @@ void Compiled::eval_gpu(
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
cu::KernelArgs args;
|
||||
// Put inputs.
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
@@ -185,26 +187,26 @@ void Compiled::eval_gpu(
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
mod.append_arg(x);
|
||||
args.append(x);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
mod.append_arg(strides_vec[strides_index++]);
|
||||
args.append_ptr(strides_vec[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Put outputs.
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
for (auto& x : outputs) {
|
||||
mod.append_arg(x);
|
||||
args.append(x);
|
||||
}
|
||||
|
||||
// Put shape and size.
|
||||
if (!contiguous) {
|
||||
mod.append_arg(shape);
|
||||
args.append_ptr(shape.data());
|
||||
}
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(outputs[0].data_size());
|
||||
args.append<int64_t>(outputs[0].data_size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(outputs[0].data_size());
|
||||
args.append<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
@@ -222,9 +224,10 @@ void Compiled::eval_gpu(
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, outputs[0], large);
|
||||
});
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -35,24 +35,25 @@ void copy_contiguous(
|
||||
array& out,
|
||||
int64_t in_offset,
|
||||
int64_t out_offset) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -55,50 +55,54 @@ void copy_general(
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
auto kernel =
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param<ndim_constant()>(shape),
|
||||
const_param<ndim_constant()>(strides_in),
|
||||
const_param<ndim_constant()>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
auto kernel =
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
const_param<ndim_constant()>(shape),
|
||||
const_param<ndim_constant()>(strides_in),
|
||||
const_param<ndim_constant()>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -61,54 +61,55 @@ void copy_general_dynamic(
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::copy_gg_dynamic_nd<
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in),
|
||||
const_param<dims_constant()>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::
|
||||
copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in),
|
||||
const_param<dims_constant()>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -50,45 +50,49 @@ void copy_general_input(
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,38 +2,28 @@
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <future>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
// This should be less than 255
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
int cuda_graph_cache_size() {
|
||||
static int cache_size = []() {
|
||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||
}();
|
||||
return cache_size;
|
||||
}
|
||||
|
||||
namespace cu {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::schedule_cuda_stream() {
|
||||
// TODO: Return a stream that maximizes parallelism.
|
||||
return stream_;
|
||||
}
|
||||
|
||||
cudaStream_t DeviceStream::last_cuda_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
if (!encoder_) {
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
|
||||
@@ -67,49 +57,254 @@ void Device::make_current() {
|
||||
}
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it == streams_.end()) {
|
||||
it = streams_.try_emplace(s.index, *this).first;
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
||||
size_t num_nodes;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
||||
if (num_nodes == 1) {
|
||||
cudaGraphNode_t captured_node;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
||||
CUDA_KERNEL_NODE_PARAMS params;
|
||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms));
|
||||
enc.insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
} else {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
|
||||
enc.insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
|
||||
}
|
||||
|
||||
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
||||
: enc(enc) {
|
||||
enc.in_concurrent_ = true;
|
||||
}
|
||||
|
||||
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||
enc.in_concurrent_ = false;
|
||||
|
||||
// Use an empty graph node for synchronization
|
||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||
enc.empty_node_count_++;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
|
||||
|
||||
// Insert the concurrent -> empty node dependencies
|
||||
for (auto& from : enc.concurrent_nodes_) {
|
||||
enc.from_nodes_.push_back(from.node);
|
||||
enc.to_nodes_.push_back(empty.node);
|
||||
enc.graph_key_ += from.id;
|
||||
enc.graph_key_ += from.node_type;
|
||||
enc.graph_key_ += empty.id;
|
||||
enc.graph_key_ += empty.node_type;
|
||||
}
|
||||
|
||||
// Insert the input -> concurrent node dependencies without updating output
|
||||
// nodes
|
||||
auto outputs = std::move(enc.active_outputs_);
|
||||
enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));
|
||||
|
||||
// Update output node to be the empty node
|
||||
for (auto o : outputs) {
|
||||
enc.node_map_.emplace(o, empty).first->second = empty;
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
||||
if (node.node_type == 'G') {
|
||||
graph_node_count_++;
|
||||
}
|
||||
node.id = std::to_string(node_count_++);
|
||||
if (in_concurrent_) {
|
||||
concurrent_nodes_.push_back(std::move(node));
|
||||
} else {
|
||||
std::vector<GraphNode> nodes;
|
||||
nodes.push_back(std::move(node));
|
||||
insert_graph_dependencies(std::move(nodes));
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
std::vector<GraphNode> deps;
|
||||
{
|
||||
// Dependencies must be added in the same order to produce a consistent
|
||||
// topology
|
||||
std::unordered_set<cudaGraphNode_t> set_deps;
|
||||
for (auto d : active_deps_) {
|
||||
if (auto it = node_map_.find(d); it != node_map_.end()) {
|
||||
auto [_, inserted] = set_deps.insert(it->second.node);
|
||||
if (inserted) {
|
||||
deps.push_back(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
active_deps_.clear();
|
||||
|
||||
for (auto o : active_outputs_) {
|
||||
for (auto& node : nodes) {
|
||||
node_map_.emplace(o, node).first->second = node;
|
||||
}
|
||||
}
|
||||
active_outputs_.clear();
|
||||
|
||||
for (auto& from : deps) {
|
||||
for (auto& to : nodes) {
|
||||
from_nodes_.push_back(from.node);
|
||||
to_nodes_.push_back(to.node);
|
||||
graph_key_ += from.id;
|
||||
graph_key_ += from.node_type;
|
||||
graph_key_ += to.id;
|
||||
graph_key_ += to.node_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
auto it = encoders_.find(s.index);
|
||||
if (it == encoders_.end()) {
|
||||
it = encoders_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& s)
|
||||
: device_(s.device()), stream_(s) {}
|
||||
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
|
||||
for (auto& [_, graph_exec] : graphs) {
|
||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
|
||||
}
|
||||
graphs.clear();
|
||||
}
|
||||
|
||||
CommandEncoder::~CommandEncoder() {
|
||||
clear_graphs(graph_cache_);
|
||||
}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
void CommandEncoder::set_input_array(const array& arr) {
|
||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||
active_deps_.push_back(id);
|
||||
}
|
||||
|
||||
// There is no kernel running, run completion handlers immediately.
|
||||
if (!has_gpu_work_) {
|
||||
worker_.consume_in_this_thread();
|
||||
return;
|
||||
}
|
||||
has_gpu_work_ = false;
|
||||
void CommandEncoder::set_output_array(const array& arr) {
|
||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||
active_deps_.push_back(id);
|
||||
active_outputs_.push_back(id);
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
|
||||
// Signaling kernel completion is expensive, delay until enough batches.
|
||||
// TODO: This number is arbitrarily picked, profile for a better stragety.
|
||||
if (worker_.uncommited_batches() > 8) {
|
||||
void CommandEncoder::maybe_commit() {
|
||||
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
void** params) {
|
||||
cudaKernelNodeParams kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDim = grid_dim;
|
||||
kernel_params.blockDim = block_dim;
|
||||
kernel_params.kernelParams = params;
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
CUfunction func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
void** params) {
|
||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDimX = grid_dim.x;
|
||||
kernel_params.gridDimY = grid_dim.y;
|
||||
kernel_params.gridDimZ = grid_dim.z;
|
||||
kernel_params.blockDimX = block_dim.x;
|
||||
kernel_params.blockDimY = block_dim.y;
|
||||
kernel_params.blockDimZ = block_dim.z;
|
||||
kernel_params.kernelParams = params;
|
||||
CUgraphNode node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit(stream_.last_cuda_stream());
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
if (node_count_ > 0) {
|
||||
if (!from_nodes_.empty()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
||||
}
|
||||
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(node_count_);
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(graph_node_count_);
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(empty_node_count_);
|
||||
auto [it, _] = graph_cache_.emplace(graph_key_, nullptr);
|
||||
auto& graph_exec = it->second;
|
||||
|
||||
if (graph_exec != NULL) {
|
||||
cudaGraphExecUpdateResultInfo update_result;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &update_result);
|
||||
if (update_result.result != cudaGraphExecUpdateSuccess) {
|
||||
cudaGetLastError();
|
||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
|
||||
graph_exec = NULL;
|
||||
}
|
||||
}
|
||||
if (graph_exec == NULL) {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// TODO smarter cache policy
|
||||
if (graph_cache_.size() > cuda_graph_cache_size()) {
|
||||
clear_graphs(graph_cache_);
|
||||
}
|
||||
|
||||
// Reset state
|
||||
node_count_ = 0;
|
||||
graph_node_count_ = 0;
|
||||
from_nodes_.clear();
|
||||
to_nodes_.clear();
|
||||
graph_key_.clear();
|
||||
node_map_.clear();
|
||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
worker_.commit(stream_);
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
stream().synchronize();
|
||||
cudaStreamSynchronize(stream_);
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
@@ -127,12 +322,8 @@ Device& device(mlx::core::Device device) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
return device(s.device).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
return device(s.device).get_command_encoder(s);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
@@ -7,41 +7,108 @@
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
struct CaptureContext {
|
||||
CaptureContext(CommandEncoder& enc);
|
||||
~CaptureContext();
|
||||
cudaGraph_t graph;
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc);
|
||||
~ConcurrentContext();
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
explicit CommandEncoder(Device& d);
|
||||
~CommandEncoder();
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
// Return a cuda stream for launching kernels.
|
||||
cudaStream_t schedule_cuda_stream();
|
||||
|
||||
// Return the last cuda stream used.
|
||||
cudaStream_t last_cuda_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
CaptureContext capture_context() {
|
||||
return CaptureContext{*this};
|
||||
}
|
||||
ConcurrentContext concurrent_context() {
|
||||
return ConcurrentContext{*this};
|
||||
}
|
||||
|
||||
void set_input_array(const array& arr);
|
||||
void set_output_array(const array& arr);
|
||||
|
||||
template <typename F, typename... Params>
|
||||
void
|
||||
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
|
||||
constexpr size_t num = sizeof...(Params);
|
||||
void* ptrs[num];
|
||||
size_t i = 0;
|
||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||
std::forward<Params>(params)),
|
||||
...);
|
||||
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
|
||||
}
|
||||
|
||||
void add_kernel_node(
|
||||
CUfunction func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
void** params);
|
||||
|
||||
void
|
||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void maybe_commit();
|
||||
void commit();
|
||||
|
||||
CudaStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
// Wait until kernels and completion handlers are finished
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
struct GraphNode {
|
||||
cudaGraphNode_t node;
|
||||
// K = kernel
|
||||
// E = empty
|
||||
// G = subgraph
|
||||
char node_type;
|
||||
std::string id;
|
||||
};
|
||||
|
||||
void insert_graph_dependencies(GraphNode node);
|
||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||
|
||||
CudaStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
cudaGraph_t graph_;
|
||||
Worker worker_;
|
||||
char node_count_{0};
|
||||
char graph_node_count_{0};
|
||||
char empty_node_count_{0};
|
||||
bool in_concurrent_{false};
|
||||
std::vector<cudaGraphNode_t> from_nodes_;
|
||||
std::vector<cudaGraphNode_t> to_nodes_;
|
||||
std::string graph_key_;
|
||||
std::vector<GraphNode> concurrent_nodes_;
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
|
||||
std::vector<std::uintptr_t> active_deps_;
|
||||
std::vector<std::uintptr_t> active_outputs_;
|
||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
@@ -55,7 +122,7 @@ class Device {
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
@@ -75,67 +142,10 @@ class Device {
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
cublasLtHandle_t lt_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a cuda stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(cudaStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_cuda_error("kernel launch", cudaGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
// Wait until kernels and completion handlers are finished
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
std::unordered_map<int, CommandEncoder> encoders_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Return an execution policy that does not sync for result.
|
||||
|
||||
@@ -37,22 +37,20 @@ void eval(array& arr) {
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
if (encoder.has_gpu_work()) {
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
encoder.end_encoding();
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
encoder.maybe_commit();
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
|
||||
@@ -61,7 +61,9 @@ void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
wait(cu::get_stream(s).last_cuda_stream());
|
||||
auto& enc = cu::get_command_encoder(s);
|
||||
enc.commit();
|
||||
wait(enc.stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +76,9 @@ void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
record(cu::get_stream(s).last_cuda_stream());
|
||||
auto& enc = cu::get_command_encoder(s);
|
||||
enc.commit();
|
||||
record(enc.stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,11 +140,9 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { wait(stream, value); });
|
||||
encoder.commit();
|
||||
wait(encoder.stream(), value);
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,11 +164,9 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.launch_kernel(
|
||||
encoder.stream().last_cuda_stream(),
|
||||
[this, value](cudaStream_t stream) { signal(stream, value); });
|
||||
encoder.commit();
|
||||
signal(encoder.stream(), value);
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "cuda_jit_sources.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvrtc.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
@@ -22,7 +25,7 @@ namespace {
|
||||
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
|
||||
|
||||
void append_indices_arg(
|
||||
cu::JitModule& mod,
|
||||
cu::KernelArgs& args,
|
||||
const std::vector<array>& inputs,
|
||||
int nidx,
|
||||
int idx_ndim) {
|
||||
@@ -30,7 +33,7 @@ void append_indices_arg(
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
indices[i] = inputs[i + 1].data<void>();
|
||||
}
|
||||
mod.append_arg(std::move(indices));
|
||||
args.append(std::move(indices));
|
||||
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
@@ -38,7 +41,7 @@ void append_indices_arg(
|
||||
idx_ndim,
|
||||
indices_shape.data() + i * idx_ndim);
|
||||
}
|
||||
mod.append_arg(std::move(indices_shape));
|
||||
args.append(std::move(indices_shape));
|
||||
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
@@ -46,7 +49,7 @@ void append_indices_arg(
|
||||
idx_ndim,
|
||||
indices_strides.data() + i * idx_ndim);
|
||||
}
|
||||
mod.append_arg(std::move(indices_strides));
|
||||
args.append(std::move(indices_strides));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
});
|
||||
|
||||
mod.append_arg(src);
|
||||
mod.append_arg(out);
|
||||
cu::KernelArgs args;
|
||||
args.append(src);
|
||||
args.append(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(out.size());
|
||||
args.append<int64_t>(out.size());
|
||||
} else {
|
||||
mod.append_arg<int32_t>(out.size());
|
||||
args.append<int32_t>(out.size());
|
||||
}
|
||||
mod.append_ndim_arg(src.shape());
|
||||
mod.append_ndim_arg(src.strides());
|
||||
mod.append_arg<int32_t>(src.ndim());
|
||||
mod.append_ndim_arg(slice_sizes_);
|
||||
mod.append_arg(slice_size);
|
||||
mod.append_arg(axes_);
|
||||
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||
args.append_ndim(src.shape());
|
||||
args.append_ndim(src.strides());
|
||||
args.append<int32_t>(src.ndim());
|
||||
args.append_ndim(slice_sizes_);
|
||||
args.append(slice_size);
|
||||
args.append(axes_);
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||
@@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, out, large);
|
||||
});
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
});
|
||||
|
||||
mod.append_arg(upd);
|
||||
mod.append_arg(out);
|
||||
cu::KernelArgs args;
|
||||
args.append(upd);
|
||||
args.append(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd.size());
|
||||
args.append<int64_t>(upd.size());
|
||||
} else {
|
||||
mod.append_arg<int32_t>(upd.size());
|
||||
args.append<int32_t>(upd.size());
|
||||
}
|
||||
mod.append_ndim_arg(upd.shape());
|
||||
mod.append_ndim_arg(upd.strides());
|
||||
mod.append_arg<int32_t>(upd.ndim());
|
||||
args.append_ndim(upd.shape());
|
||||
args.append_ndim(upd.strides());
|
||||
args.append<int32_t>(upd.ndim());
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||
args.append<int64_t>(upd_post_idx_size);
|
||||
} else {
|
||||
mod.append_arg<int32_t>(upd_post_idx_size);
|
||||
args.append<int32_t>(upd_post_idx_size);
|
||||
}
|
||||
mod.append_ndim_arg(out.shape());
|
||||
mod.append_ndim_arg(out.strides());
|
||||
mod.append_arg<int32_t>(out.ndim());
|
||||
mod.append_arg(axes_);
|
||||
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||
args.append_ndim(out.shape());
|
||||
args.append_ndim(out.strides());
|
||||
args.append<int32_t>(out.ndim());
|
||||
args.append(axes_);
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||
@@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, upd, large);
|
||||
});
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
size_t idx_size_axis = idx.shape(axis_);
|
||||
|
||||
mod.append_arg(src);
|
||||
mod.append_arg(idx);
|
||||
mod.append_arg(out);
|
||||
cu::KernelArgs args;
|
||||
args.append(src);
|
||||
args.append(idx);
|
||||
args.append(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(idx_size_pre);
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
args.append<int64_t>(idx_size_pre);
|
||||
args.append<int64_t>(idx_size_axis);
|
||||
args.append<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
args.append<int32_t>(idx_size_pre);
|
||||
args.append<int32_t>(idx_size_axis);
|
||||
args.append<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(src.strides(), axis_));
|
||||
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||
mod.append_arg<int32_t>(axis_);
|
||||
mod.append_arg(src.shape(axis_));
|
||||
mod.append_arg(src.strides(axis_));
|
||||
mod.append_arg(idx.strides(axis_));
|
||||
args.append(remove_index(idx.shape(), axis_));
|
||||
args.append(remove_index(src.strides(), axis_));
|
||||
args.append(remove_index(idx.strides(), axis_));
|
||||
args.append<int32_t>(axis_);
|
||||
args.append(src.shape(axis_));
|
||||
args.append(src.strides(axis_));
|
||||
args.append(idx.strides(axis_));
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||
@@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||
});
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
size_t idx_size_axis = idx.shape(axis_);
|
||||
|
||||
mod.append_arg(upd);
|
||||
mod.append_arg(idx);
|
||||
mod.append_arg(out);
|
||||
cu::KernelArgs args;
|
||||
args.append(upd);
|
||||
args.append(idx);
|
||||
args.append(out);
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(idx_size_pre);
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
args.append<int64_t>(idx_size_pre);
|
||||
args.append<int64_t>(idx_size_axis);
|
||||
args.append<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
args.append<int32_t>(idx_size_pre);
|
||||
args.append<int32_t>(idx_size_axis);
|
||||
args.append<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||
mod.append_arg<int32_t>(axis_);
|
||||
mod.append_arg(out.shape(axis_));
|
||||
mod.append_arg(upd.strides(axis_));
|
||||
mod.append_arg(idx.strides(axis_));
|
||||
args.append(remove_index(idx.shape(), axis_));
|
||||
args.append(remove_index(upd.strides(), axis_));
|
||||
args.append(remove_index(idx.strides(), axis_));
|
||||
args.append<int32_t>(axis_);
|
||||
args.append(out.shape(axis_));
|
||||
args.append(upd.strides(axis_));
|
||||
args.append(idx.strides(axis_));
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||
@@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||
});
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) {
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
|
||||
|
||||
void check_cu_error(const char* name, CUresult err) {
|
||||
if (err != CUDA_SUCCESS) {
|
||||
const char* err_str = "Unknown error";
|
||||
cuGetErrorString(err, &err_str);
|
||||
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
|
||||
}
|
||||
}
|
||||
|
||||
// Return the location of the CUDA toolkit.
|
||||
const std::string& cuda_home() {
|
||||
static std::string home = []() -> std::string {
|
||||
@@ -280,60 +270,13 @@ JitModule::JitModule(
|
||||
// Load kernels.
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels_[name] = kernel;
|
||||
}
|
||||
}
|
||||
|
||||
JitModule::~JitModule() {
|
||||
CHECK_CU_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
void JitModule::launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread) {
|
||||
CUfunction kernel = get_kernel(kernel_name);
|
||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||
int _, block_dim;
|
||||
CHECK_CU_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
Dims num_blocks{1, 1, 1};
|
||||
if (large) {
|
||||
num_blocks =
|
||||
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
|
||||
std::get<0>(num_blocks) =
|
||||
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
|
||||
} else {
|
||||
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
|
||||
}
|
||||
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
|
||||
}
|
||||
|
||||
void JitModule::launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims) {
|
||||
CHECK_CU_ERROR(cuLaunchKernel(
|
||||
kernel,
|
||||
std::get<0>(num_blocks),
|
||||
std::get<1>(num_blocks),
|
||||
std::get<2>(num_blocks),
|
||||
std::get<0>(block_dims),
|
||||
std::get<1>(block_dims),
|
||||
std::get<2>(block_dims),
|
||||
0,
|
||||
stream,
|
||||
args_.data(),
|
||||
nullptr));
|
||||
args_.clear();
|
||||
storage_.clear();
|
||||
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
@@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void JitModule::append_ptr_arg(const void* v) {
|
||||
args_.push_back(const_cast<void*>(v));
|
||||
}
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <deque>
|
||||
@@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair<
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
struct KernelArgs {
|
||||
void** args() {
|
||||
return args_.data();
|
||||
}
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
|
||||
void append_arg(const array& a) {
|
||||
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
void append(const array& a) {
|
||||
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(T val) {
|
||||
void append(T val) {
|
||||
storage_.emplace_back(val);
|
||||
append_ptr_arg(&storage_.back());
|
||||
append_ptr(&storage_.back());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_arg(std::vector<T> vec) {
|
||||
void append(std::vector<T> vec) {
|
||||
if (vec.empty()) {
|
||||
// The nullptr can not be used as arg, pass something not null.
|
||||
append_arg(std::monostate{});
|
||||
append(std::monostate{});
|
||||
} else {
|
||||
append_ptr_arg(vec.data());
|
||||
append_ptr(vec.data());
|
||||
storage_.emplace_back(std::move(vec));
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim_arg(const std::vector<T>& vec) {
|
||||
void append_ndim(std::vector<T> vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
std::vector<T> copied(NDIM);
|
||||
std::copy(vec.begin(), vec.end(), copied.data());
|
||||
append_arg(std::move(copied));
|
||||
vec.resize(NDIM);
|
||||
append(std::move(vec));
|
||||
}
|
||||
|
||||
// Launch kernel with |kernel_name| that each thread works on
|
||||
// |work_per_thread| elements of |arr|.
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
const std::string& kernel_name,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
|
||||
void launch_kernel(
|
||||
CUstream stream,
|
||||
CUfunction kernel,
|
||||
Dims num_blocks,
|
||||
Dims block_dims);
|
||||
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
void append_ptr(const void* v) {
|
||||
args_.push_back(const_cast<void*>(v));
|
||||
}
|
||||
|
||||
private:
|
||||
void append_ptr_arg(const void* v);
|
||||
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
@@ -105,6 +82,23 @@ class JitModule {
|
||||
std::deque<Arg> storage_;
|
||||
};
|
||||
|
||||
class JitModule {
|
||||
public:
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
};
|
||||
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <fmt/format.h>
|
||||
@@ -120,7 +121,13 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
if constexpr (std::is_same_v<T, CUfunction>) {
|
||||
CHECK_CUDA_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
} else {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
}
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
|
||||
@@ -258,23 +258,23 @@ void LayerNorm::eval_gpu(
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -289,21 +289,25 @@ void LayerNormVJP::eval_gpu(
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
auto check_input = [&s](const array& x, bool& copied) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
copied = false;
|
||||
return x;
|
||||
}
|
||||
copied = true;
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
return x_copy;
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
bool copied;
|
||||
auto x = check_input(inputs[0], copied);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
bool g_copied;
|
||||
auto g = check_input(inputs[3], g_copied);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
@@ -334,8 +338,10 @@ void LayerNormVJP::eval_gpu(
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
bool g_in_gw = false;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
g_in_gw = true;
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
@@ -343,41 +349,47 @@ void LayerNormVJP::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
// The gradient for b in case we had a b.
|
||||
bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size);
|
||||
if (has_gb) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
|
||||
// Insert dependency if `g` was donated
|
||||
if ((g_in_gx || g_in_gw) && has_gb) {
|
||||
encoder.set_input_array(gb);
|
||||
}
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -143,16 +143,18 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
in.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -42,7 +42,8 @@ class MatMul {
|
||||
int64_t ldb,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride) {
|
||||
int64_t b_batch_stride)
|
||||
: handle_(device.lt_handle()) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
@@ -147,7 +148,7 @@ class MatMul {
|
||||
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
|
||||
int ret = 0;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
|
||||
encoder.device().lt_handle(),
|
||||
handle_,
|
||||
matmul_desc_,
|
||||
a_desc_,
|
||||
b_desc_,
|
||||
@@ -172,25 +173,24 @@ class MatMul {
|
||||
workspace_ptr = workspace.data<void>();
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
encoder.device().lt_handle(),
|
||||
matmul_desc_,
|
||||
&alpha,
|
||||
a,
|
||||
a_desc_,
|
||||
b,
|
||||
b_desc_,
|
||||
&beta,
|
||||
c ? c : out,
|
||||
c ? c_desc_ : out_desc_,
|
||||
out,
|
||||
out_desc_,
|
||||
&heuristic_.algo,
|
||||
workspace_ptr,
|
||||
heuristic_.workspaceSize,
|
||||
stream));
|
||||
});
|
||||
auto capture = encoder.capture_context();
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
handle_,
|
||||
matmul_desc_,
|
||||
&alpha,
|
||||
a,
|
||||
a_desc_,
|
||||
b,
|
||||
b_desc_,
|
||||
&beta,
|
||||
c ? c : out,
|
||||
c ? c_desc_ : out_desc_,
|
||||
out,
|
||||
out_desc_,
|
||||
&heuristic_.algo,
|
||||
workspace_ptr,
|
||||
heuristic_.workspaceSize,
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -259,6 +259,7 @@ class MatMul {
|
||||
return desc;
|
||||
}
|
||||
|
||||
cublasLtHandle_t handle_{nullptr};
|
||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||
@@ -273,7 +274,7 @@ class MatMul {
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, int64_t, array>
|
||||
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
|
||||
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (sty == 1 && stx == arr.shape(-1)) {
|
||||
@@ -283,7 +284,7 @@ check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
enc.add_temporary(arr_copy);
|
||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||
}
|
||||
}
|
||||
@@ -317,13 +318,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||
|
||||
for (auto& temp : copies) {
|
||||
encoder.add_temporary(temp);
|
||||
}
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
@@ -348,7 +344,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::MatMul matmul(
|
||||
encoder.device(),
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
@@ -373,6 +369,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
@@ -405,14 +402,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
|
||||
|
||||
for (auto& temp : copies) {
|
||||
encoder.add_temporary(temp);
|
||||
}
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
@@ -440,7 +432,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::MatMul matmul(
|
||||
encoder.device(),
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
@@ -478,6 +470,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
|
||||
@@ -24,23 +24,21 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&, this](cudaStream_t stream) {
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
auto capture = encoder.capture_context();
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(encoder.stream()),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -156,34 +156,39 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(keys);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
int64_t total = grid_dims.x * grid_dims.y;
|
||||
int32_t threads_y = 1;
|
||||
while ((total / threads_y) >= (1U << 31)) {
|
||||
threads_y *= 2;
|
||||
}
|
||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||
if (keys.flags().row_contiguous) {
|
||||
cu::rbitsc<<<grid, block, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key);
|
||||
} else {
|
||||
cu::rbits<<<grid, block, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key,
|
||||
keys.ndim(),
|
||||
const_param(keys.shape()),
|
||||
const_param(keys.strides()));
|
||||
}
|
||||
});
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
int64_t total = grid_dims.x * grid_dims.y;
|
||||
int32_t threads_y = 1;
|
||||
while ((total / threads_y) >= (1U << 31)) {
|
||||
threads_y *= 2;
|
||||
}
|
||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||
auto& stream = encoder.stream();
|
||||
if (keys.flags().row_contiguous) {
|
||||
encoder.add_kernel_node(
|
||||
cu::rbitsc,
|
||||
grid,
|
||||
block,
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key);
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::rbits,
|
||||
grid,
|
||||
block,
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key,
|
||||
keys.ndim(),
|
||||
const_param(keys.shape()),
|
||||
const_param(keys.strides()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -110,19 +110,20 @@ void all_reduce(
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<T*>(indata),
|
||||
intermediate.data<U>(),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
blocks,
|
||||
threads,
|
||||
static_cast<T*>(indata),
|
||||
intermediate.data<U>(),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -135,16 +136,20 @@ void all_reduce(
|
||||
}
|
||||
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
static_cast<T*>(indata), out.data<U>(), block_step, insize);
|
||||
});
|
||||
dispatch_all_types(dt, [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
blocks,
|
||||
threads,
|
||||
static_cast<T*>(indata),
|
||||
out.data<U>(),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -214,26 +214,24 @@ void col_reduce_looped(
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
|
||||
});
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, blocks, indata, out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -32,18 +32,16 @@ void init_reduce(
|
||||
}
|
||||
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::init_reduce<T, U, OP>;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||
grid.x = (grid.x + 1023) / 1024;
|
||||
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
|
||||
});
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::init_reduce<T, U, OP>;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||
grid.x = (grid.x + 1023) / 1024;
|
||||
encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -245,34 +245,32 @@ void row_reduce_simple(
|
||||
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
||||
if (grid.x >= 1024) {
|
||||
grid.x = (grid.x + 1) / 2;
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
||||
if (grid.x >= 1024) {
|
||||
grid.x = (grid.x + 1) / 2;
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
indata, out.data<U>(), out.size(), plan.shape.back());
|
||||
});
|
||||
int size = plan.shape.back();
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, indata, out.data<U>(), out.size(), size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -293,43 +291,39 @@ void row_reduce_looped(
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
||||
kernel = cu::row_reduce_looped<
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim(),
|
||||
threads_constant(),
|
||||
N_READS>;
|
||||
block.x = threads_constant();
|
||||
});
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
||||
kernel = cu::row_reduce_looped<
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim.value,
|
||||
threads_constant.value,
|
||||
N_READS>;
|
||||
block.x = threads_constant.value;
|
||||
});
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
indata, out.data<U>(), out.size(), args);
|
||||
});
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, indata, out.data<U>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -224,21 +224,21 @@ void RMSNorm::eval_gpu(
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -253,20 +253,24 @@ void RMSNormVJP::eval_gpu(
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
auto check_input = [&s](const array& x, bool& copied) {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
copied = false;
|
||||
return x;
|
||||
}
|
||||
copied = true;
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
return x_copy;
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
bool copied;
|
||||
auto x = check_input(inputs[0], copied);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
auto [g, g_copied] = check_input(inputs[2]);
|
||||
bool g_copied;
|
||||
auto g = check_input(inputs[2], g_copied);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
@@ -310,30 +314,31 @@ void RMSNormVJP::eval_gpu(
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
|
||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -308,74 +308,89 @@ void RoPE::eval_gpu(
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(donated ? out : in);
|
||||
encoder.set_input_array(offset);
|
||||
if (with_freqs) {
|
||||
encoder.set_input_array(inputs[2]);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
|
||||
dispatch_bool(traditional_, [&](auto traditional) {
|
||||
dispatch_bool(forward_, [&](auto forward) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = cu::rope_single<DataType, traditional(), forward()>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional(), forward()>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, traditional(), forward()>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
|
||||
dispatch_bool(traditional_, [&](auto traditional) {
|
||||
dispatch_bool(forward_, [&](auto forward) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if (single && !with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_single<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_freqs<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -141,19 +141,21 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(
|
||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
|
||||
}
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
|
||||
constexpr int N_READS = 4;
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
in.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct OffsetTransform {
|
||||
int nsort;
|
||||
|
||||
@@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = cuda_type_t<CTYPE>;
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), OffsetTransform{nsort});
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(
|
||||
allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(indices);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(indices.data_size()),
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
auto& stream = encoder.stream();
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
using Type = cuda_type_t<CTYPE>;
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), OffsetTransform{nsort});
|
||||
if (argsort) {
|
||||
// Indices in the sorted dimension.
|
||||
array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(indices);
|
||||
|
||||
// In argsort though we don't need the result of sorted values, the
|
||||
// API requires us to provide an array to store it.
|
||||
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
|
||||
encoder.add_temporary(discard);
|
||||
// In argsort though we don't need the result of sorted values, the
|
||||
// API requires us to provide an array to store it.
|
||||
array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
segmented_sort_pairs(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
} else {
|
||||
segmented_sort(
|
||||
encoder,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
}
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
|
||||
// Start capturing after allocations
|
||||
auto capture = encoder.capture_context();
|
||||
thrust::transform(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(indices.data_size()),
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
temp.data<void>(),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
discard.data<Type>(),
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream));
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"CUDA backend does not support sorting complex numbers");
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
|
||||
// Start capturing after allocations
|
||||
auto capture = encoder.capture_context();
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
temp.data<void>(),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"CUDA backend does not support sorting complex numbers");
|
||||
}
|
||||
});
|
||||
|
||||
if (!is_segmented_sort) {
|
||||
|
||||
@@ -91,73 +91,80 @@ void ternary_op_gpu_inplace(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
if (topt == TernaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides),
|
||||
const_param<dims_constant()>(c_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
if (topt == TernaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
const_param(c_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides),
|
||||
const_param<dims_constant()>(c_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
const_param(c_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -9,14 +9,38 @@
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void unary_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
out[index] = Op{}(in[index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void unary_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
|
||||
out[index] = Op{}(in[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
@@ -71,38 +95,61 @@ void unary_op_gpu_inplace(
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
bool contig = in.flags().contiguous;
|
||||
bool large;
|
||||
if (!contig) {
|
||||
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
} else {
|
||||
large = in.data_size() > UINT32_MAX;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
dispatch_bool(large, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto policy = cu::thrust_policy(stream);
|
||||
auto in_ptr = thrust::device_pointer_cast(in.data<InType>());
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
if (in.flags().contiguous) {
|
||||
thrust::transform(
|
||||
policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op());
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
if (contig) {
|
||||
auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), large);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(strides),
|
||||
shape.size());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do unary op {} on input of {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do unary op {} on input of {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -24,6 +24,14 @@ void check_cuda_error(const char* name, cudaError_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
void check_cuda_error(const char* name, CUresult err) {
|
||||
if (err != CUDA_SUCCESS) {
|
||||
const char* err_str = "Unknown error";
|
||||
cuGetErrorString(err, &err_str);
|
||||
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
|
||||
}
|
||||
}
|
||||
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
switch (dtype) {
|
||||
case bool_:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -33,6 +34,7 @@ class CudaStream {
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
void check_cuda_error(const char* name, CUresult err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
@@ -187,6 +187,9 @@ struct Max {
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce_impl(T val) {
|
||||
if (simd_any(val != val)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd_max(val);
|
||||
}
|
||||
|
||||
@@ -198,7 +201,35 @@ struct Max {
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
if (metal::isnan(a) || metal::isnan(b)) {
|
||||
return static_cast<T>(NAN);
|
||||
} else {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
|
||||
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
|
||||
|
||||
if (!real_is_nan && !imag_is_nan) {
|
||||
return a > b ? a : b;
|
||||
} else if (real_is_nan && !imag_is_nan) {
|
||||
return complex64_t(
|
||||
static_cast<float>(NAN), a.imag > b.imag ? a.imag : b.imag);
|
||||
} else if (!real_is_nan && imag_is_nan) {
|
||||
return complex64_t(
|
||||
a.real > b.real ? a.real : b.real, static_cast<float>(NAN));
|
||||
} else {
|
||||
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -688,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
perm = expand_dims(perm, -1, s);
|
||||
take_axis -= 1;
|
||||
}
|
||||
auto pb = take_along_axis(b, perm, take_axis);
|
||||
auto pb = take_along_axis(b, perm, take_axis, s);
|
||||
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
|
||||
return solve_triangular(luf[2], y, /* upper = */ true, s);
|
||||
}
|
||||
|
||||
@@ -114,6 +114,12 @@ class Module(dict):
|
||||
super(Module, self).__setattr__(key, val)
|
||||
self.pop(key, None)
|
||||
|
||||
def __delattr__(self, name):
|
||||
if (val := self.get(name, None)) is not None:
|
||||
del self[name]
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
|
||||
|
||||
@@ -391,9 +391,11 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
scale = mx.array(2.0)
|
||||
y = mx.load(save_file)
|
||||
mx.eval(y)
|
||||
mx.synchronize()
|
||||
load_only = mx.get_peak_memory()
|
||||
y = mx.load(save_file) * scale
|
||||
mx.eval(y)
|
||||
mx.synchronize()
|
||||
load_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(load_only, load_with_binary)
|
||||
|
||||
@@ -274,6 +274,11 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
m = MyModel()
|
||||
m.update_modules(m.leaf_modules())
|
||||
|
||||
def test_parameter_deletion(self):
|
||||
m = nn.Linear(32, 32)
|
||||
del m.weight
|
||||
self.assertFalse(hasattr(m, "weight"))
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
||||
@@ -153,6 +153,63 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
|
||||
check(x, (1, 3, 5, 7, 9))
|
||||
|
||||
def test_nanpropagation(self):
|
||||
dtypes = [
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"float16",
|
||||
"float32",
|
||||
]
|
||||
|
||||
for dtype in dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x = (mx.random.normal((4, 4))).astype(getattr(mx, dtype))
|
||||
indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)
|
||||
for idx in indices:
|
||||
x[*idx] = mx.nan
|
||||
x_np = np.array(x)
|
||||
|
||||
for op in ["max"]:
|
||||
for axis in [0, 1]:
|
||||
out = getattr(mx, op)(x, axis=axis)
|
||||
ref = getattr(np, op)(x_np, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
def test_nanpropagation_complex64(self):
|
||||
complex_array_1 = mx.array(
|
||||
[1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64
|
||||
).reshape(2, 2)
|
||||
complex_array_2 = mx.array(
|
||||
[1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64
|
||||
).reshape(2, 2)
|
||||
complex_array_3 = mx.array(
|
||||
[1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64
|
||||
).reshape(2, 2)
|
||||
complex_array_4 = mx.array(
|
||||
[mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64
|
||||
).reshape(2, 2)
|
||||
|
||||
np_arrays = [
|
||||
np.array(complex_array_1),
|
||||
np.array(complex_array_2),
|
||||
np.array(complex_array_3),
|
||||
np.array(complex_array_4),
|
||||
]
|
||||
|
||||
for mx_arr, np_arr in zip(
|
||||
[complex_array_1, complex_array_2, complex_array_3, complex_array_4],
|
||||
np_arrays,
|
||||
):
|
||||
for axis in [0, 1]:
|
||||
for op in ["max"]:
|
||||
out = getattr(mx, op)(mx_arr, axis=axis)
|
||||
ref = getattr(np, op)(np_arr, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
||||
@@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") {
|
||||
x = array({true, true, true, false, true, false}, {2, 3});
|
||||
CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());
|
||||
CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());
|
||||
|
||||
x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
|
||||
CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item<bool>());
|
||||
CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item<bool>());
|
||||
}
|
||||
|
||||
// Test logsumexp
|
||||
|
||||
Reference in New Issue
Block a user