mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Compare commits
3 Commits
ef631d63af...a0ae49d397
| Author | SHA1 | Date |
|---|---|---|
| | a0ae49d397 | |
| | 254476718b | |
| | 3adba92ebe | |
CMakeLists.txt (the CUDA backend's source list):
```diff
@@ -6,6 +6,7 @@
 target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
@@ -29,7 +30,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
```
mlx/backend/cuda/arange.cu (new file, 55 lines):
```diff
@@ -0,0 +1,55 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+template <typename T>
+struct Arange {
+  const T start;
+  const T step;
+
+  __device__ T operator()(uint32_t i) const {
+    return start + i * step;
+  }
+};
+
+} // namespace cu
+
+void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Arange::eval_gpu");
+  if (out.size() == 0) {
+    return;
+  }
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& encoder = cu::get_command_encoder(stream());
+  encoder.set_output_array(out);
+
+  auto capture = encoder.capture_context();
+  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
+    using CTYPE = MLX_GET_TYPE(type_tag);
+    using OutType = cuda_type_t<CTYPE>;
+    CTYPE step =
+        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
+    thrust::transform(
+        cu::thrust_policy(encoder.stream()),
+        thrust::counting_iterator<uint32_t>(0),
+        thrust::counting_iterator<uint32_t>(out.data_size()),
+        thrust::device_pointer_cast(out.data<OutType>()),
+        cu::Arange<OutType>{
+            static_cast<OutType>(start_), static_cast<OutType>(step)});
+  });
+}
+
+} // namespace mlx::core
```
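The new file implements arange as a plain `thrust::transform` over a counting iterator. Note that `step` is recomputed as `static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_)` so low-precision dtypes accumulate the step as it rounds in the output type. A minimal standalone sketch of the same pattern (plain CUDA/Thrust, no MLX types; the functor name and values are illustrative):

```cpp
#include <cstdint>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

// Same shape as cu::Arange above: map element index i to start + i * step.
struct ArangeOp {
  float start;
  float step;
  __device__ float operator()(uint32_t i) const {
    return start + i * step;
  }
};

int main() {
  thrust::device_vector<float> out(8);
  thrust::transform(
      thrust::counting_iterator<uint32_t>(0),
      thrust::counting_iterator<uint32_t>(static_cast<uint32_t>(out.size())),
      out.begin(),
      ArangeOp{1.0f, 0.5f}); // 1.0, 1.5, 2.0, ..., 4.5
  return 0;
}
```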
```diff
@@ -211,12 +211,15 @@ void binary_op_gpu_inplace(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::
-            binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+            get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::binary_g_nd<
+                Op,
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
             num_blocks,
             block_dims,
             a.data<InType>(),
@@ -228,11 +231,9 @@ void binary_op_gpu_inplace(
             const_param<dims_constant()>(b_strides));
       });
     } else {
-      auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::binary_g<Op, InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           a.data<InType>(),
@@ -258,12 +259,7 @@ void binary_op_gpu_inplace(
         kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large(),
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large(), N_READS);
       encoder.add_kernel_node(
           kernel,
           num_blocks,
```
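The pattern in these hunks repeats through the rest of the diff: the `auto kernel = ...` temporaries are dropped, the instantiated kernel template is passed straight to `add_kernel_node`, and `get_launch_args` no longer takes the kernel at all. This works because an instantiated `__global__` template decays to an ordinary function pointer. A toy sketch of that mechanism (the `launch` helper and `fill` kernel are illustrative, not MLX APIs):

```cpp
#include <cuda_runtime.h>

// A templated kernel, analogous to cu::binary_g_nd and friends.
template <typename T>
__global__ void fill(T* p, T v, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) p[i] = v;
}

// fill<float> is just a function pointer, so it can be forwarded through a
// launcher without an intermediate local variable at the call site.
template <typename Kernel, typename... Args>
void launch(Kernel kernel, dim3 blocks, dim3 threads, Args... args) {
  kernel<<<blocks, threads>>>(args...);
}

int main() {
  float* d = nullptr;
  cudaMalloc(&d, 256 * sizeof(float));
  launch(fill<float>, dim3(1), dim3(256), d, 1.0f, 256);
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}
```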
```diff
@@ -227,16 +227,15 @@ void binary_two_op_gpu_inplace(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::binary_two_g_nd<
-            Op,
-            InType,
-            OutType,
-            IdxT,
-            dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out_a, large());
-        encoder.add_kernel_node(
-            kernel,
+        auto [num_blocks, block_dims] =
+            get_launch_args(out_a, large());
+        encoder.add_kernel_node(
+            cu::binary_two_g_nd<
+                Op,
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
             num_blocks,
             block_dims,
             a.data<InType>(),
@@ -249,11 +248,10 @@ void binary_two_op_gpu_inplace(
             const_param<dims_constant()>(b_strides));
       });
     } else {
-      auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
       auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out_a, large());
+          get_launch_args(out_a, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::binary_two_g<Op, InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           a.data<InType>(),
@@ -280,7 +278,6 @@ void binary_two_op_gpu_inplace(
         kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
           out_a.data_size(),
           out_a.shape(),
           out_a.strides(),
```
```diff
@@ -294,7 +294,7 @@ void Compiled::eval_gpu(
 
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(kernel, outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
```
```diff
@@ -71,12 +71,7 @@ void copy_contiguous(
         kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large(),
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large(), N_READS);
       encoder.add_kernel_node(
           kernel,
           num_blocks,
```
```diff
@@ -71,12 +71,10 @@ void copy_general(
       data_size *= s;
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto ndim_constant) {
-        auto kernel =
-            cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
-        auto [num_blocks, block_dims] = get_launch_args(
-            kernel, data_size, shape, out.strides(), large());
+        auto [num_blocks, block_dims] =
+            get_launch_args(data_size, shape, out.strides(), large());
         encoder.add_kernel_node(
-            kernel,
+            cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
             num_blocks,
             block_dims,
             in_ptr,
@@ -87,11 +85,10 @@ void copy_general(
             const_param<ndim_constant()>(strides_out));
       });
     } else { // ndim >= 4
-      auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(
-          kernel, data_size, shape, out.strides(), large());
+      auto [num_blocks, block_dims] =
+          get_launch_args(data_size, shape, out.strides(), large());
       encoder.add_kernel_node(
-          kernel,
+          cu::copy_gg<InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in_ptr,
```
```diff
@@ -74,12 +74,13 @@ void copy_general_dynamic(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::
-            copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+        auto [num_blocks, block_dims] = get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::copy_gg_dynamic_nd<
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
             num_blocks,
             block_dims,
             in_ptr,
@@ -92,11 +93,9 @@ void copy_general_dynamic(
             dynamic_offset_out.data<int64_t>());
       });
     } else { // ndim >= 4
-      auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::copy_gg_dynamic<InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in_ptr,
```
```diff
@@ -63,12 +63,9 @@ void copy_general_input(
     int ndim = shape.size();
    if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel =
-            cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+        auto [num_blocks, block_dims] = get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
             num_blocks,
             block_dims,
             in_ptr,
@@ -78,11 +75,9 @@ void copy_general_input(
             const_param<dims_constant()>(strides_in));
       });
     } else { // ndim >= 4
-      auto kernel = cu::copy_g<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::copy_g<InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in_ptr,
```
Deleted file (the old Arange functor header; its include, mlx/backend/cuda/device/arange.cuh, is removed in a later hunk):
```diff
@@ -1,15 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-namespace mlx::core::cu {
-
-template <typename T>
-struct Arange {
-  const T start;
-  const T step;
-
-  __device__ T operator()(uint32_t i) const {
-    return start + i * step;
-  }
-};
-
-} // namespace mlx::core::cu
```
```diff
@@ -128,7 +128,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
 
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+  auto [num_blocks, block_dims] = get_launch_args(out, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
@@ -229,7 +229,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
+  auto [num_blocks, block_dims] = get_launch_args(upd, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
@@ -317,7 +317,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+  auto [num_blocks, block_dims] = get_launch_args(idx, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
@@ -421,7 +421,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+  auto [num_blocks, block_dims] = get_launch_args(idx, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
```
```diff
@@ -30,4 +30,25 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
   return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
 }
 
+std::tuple<dim3, uint> get_launch_args(
+    size_t size,
+    const Shape& shape,
+    const Strides& strides,
+    bool large,
+    int work_per_thread) {
+  size_t nthreads = cuda::ceil_div(size, work_per_thread);
+  uint block_dim = 1024;
+  if (block_dim > nthreads) {
+    block_dim = nthreads;
+  }
+  dim3 num_blocks;
+  if (large) {
+    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
+    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
+  } else {
+    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
+  }
+  return std::make_tuple(num_blocks, block_dim);
+}
+
 } // namespace mlx::core
```
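The definition added here is the former header template minus the unused `T kernel` parameter; the launch geometry depends only on the element count, shape, and `work_per_thread`. Restated on the host for a concrete case (assuming `ceil_div` is the usual rounded-up division):

```cpp
#include <cstddef>
#include <cstdio>

static std::size_t ceil_div(std::size_t a, std::size_t b) {
  return (a + b - 1) / b;
}

int main() {
  std::size_t size = 1 << 20; // 1,048,576 elements
  int work_per_thread = 4;
  std::size_t nthreads = ceil_div(size, work_per_thread); // 262,144 threads
  unsigned block_dim = 1024;
  if (block_dim > nthreads) {
    block_dim = static_cast<unsigned>(nthreads);
  }
  // Non-"large" path: a 1-D grid.
  std::size_t num_blocks_x = ceil_div(nthreads, block_dim); // 256 blocks
  std::printf("%zu blocks of %u threads\n", num_blocks_x, block_dim);
  return 0;
}
```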
```diff
@@ -122,37 +122,17 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
 // Get the num_blocks and block_dims that maximize occupancy for |kernel|,
 // assuming each thread handles |work_per_thread| elements of |arr|.
-template <typename T>
-inline std::tuple<dim3, uint> get_launch_args(
-    T kernel,
+std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1) {
-  size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
-  dim3 num_blocks;
-  if (large) {
-    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
-    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
-  } else {
-    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
-  }
-  return std::make_tuple(num_blocks, block_dim);
-}
+    int work_per_thread = 1);
 
-template <typename T>
-inline std::tuple<dim3, uint> get_launch_args(
-    T kernel,
-    const array& arr,
-    bool large,
-    int work_per_thread = 1) {
+inline std::tuple<dim3, uint>
+get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
   return get_launch_args(
-      kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
 }
 
 } // namespace mlx::core
```
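With the kernel argument gone, `get_launch_args` no longer needs a template parameter, so the header can carry a plain declaration (defined once in the previous hunk's translation unit) plus a thin inline overload for arrays. A toy single-file restatement of that split (all names illustrative, `large` handling elided):

```cpp
#include <cstddef>
#include <tuple>

// Header side: an ordinary declaration, no template needed.
std::tuple<std::size_t, unsigned> get_launch_args(
    std::size_t size, bool large, int work_per_thread = 1);

// Header side: inline convenience overload that just forwards.
struct ToyArray {
  std::size_t size;
};
inline std::tuple<std::size_t, unsigned> get_launch_args(
    const ToyArray& arr, bool large, int work_per_thread = 1) {
  return get_launch_args(arr.size, large, work_per_thread);
}

// Source side: the single out-of-line definition.
std::tuple<std::size_t, unsigned> get_launch_args(
    std::size_t size, bool /*large*/, int work_per_thread) {
  std::size_t nthreads = (size + work_per_thread - 1) / work_per_thread;
  unsigned block_dim = 1024;
  if (block_dim > nthreads) {
    block_dim = static_cast<unsigned>(nthreads);
  }
  return {(nthreads + block_dim - 1) / block_dim, block_dim};
}

int main() {
  auto [blocks, threads] = get_launch_args(ToyArray{4096}, false);
  return (blocks == 4 && threads == 1024) ? 0 : 1;
}
```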
```diff
@@ -43,20 +43,19 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
   AccT maxval = Limits<AccT>::finite_min();
   AccT normalizer = 0;
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
-    AccT vals[N_READS];
-    cub::LoadDirectBlocked(
-        r * BLOCK_DIM + block.thread_rank(),
-        make_cast_iterator<AccT>(in),
-        vals,
-        axis_size,
-        Limits<AccT>::min());
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
     prevmax = maxval;
-    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      maxval = max_op(maxval, static_cast<AccT>(vals[i]));
+    }
     // Online normalizer calculation for softmax:
     // https://github.com/NVIDIA/online-softmax
     normalizer = normalizer * softmax_exp(prevmax - maxval);
     for (int i = 0; i < N_READS; i++) {
-      normalizer = normalizer + softmax_exp(vals[i] - maxval);
+      normalizer =
+          normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
     }
   }
 
@@ -143,9 +142,9 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
-    constexpr int N_READS = 4;
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
+    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
       auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
       encoder.add_kernel_node(
           kernel,
```
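The kernel keeps the same online max/normalizer recurrence (linked above); only the loads changed, from `cub::LoadDirectBlocked` into an `AccT` array to a vectorized `load_vector` of raw `T` values cast per element. The recurrence itself, restated on the host for a small input:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float vals[] = {1.0f, 3.0f, 2.0f, 5.0f};
  float maxval = -INFINITY;
  float normalizer = 0.0f;
  for (float v : vals) {
    float prevmax = maxval;
    maxval = fmaxf(maxval, v);
    // Rescale the running sum when a new maximum appears, then add the new
    // term; this is the same update as the kernel's softmax_exp accumulation.
    normalizer = normalizer * expf(prevmax - maxval) + expf(v - maxval);
  }
  std::printf("logsumexp = %f\n", maxval + logf(normalizer)); // ~5.185
  return 0;
}
```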
primitives.cu (renamed to primitives.cpp per the CMake hunk above, now that the Arange implementation lives in arange.cu):
```diff
@@ -1,47 +1,11 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/arange.cuh"
-#include "mlx/backend/cuda/device/fp16_math.cuh"
-#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/distributed/primitives.h"
-#include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 
-#include <nvtx3/nvtx3.hpp>
-#include <thrust/device_ptr.h>
-#include <thrust/transform.h>
-
-#include <cassert>
-
 namespace mlx::core {
 
-void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Arange::eval_gpu");
-  assert(inputs.size() == 0);
-  out.set_data(allocator::malloc(out.nbytes()));
-  if (out.size() == 0) {
-    return;
-  }
-  auto& encoder = cu::get_command_encoder(stream());
-  encoder.set_output_array(out);
-  auto capture = encoder.capture_context();
-  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
-    using CTYPE = MLX_GET_TYPE(type_tag);
-    using OutType = cuda_type_t<CTYPE>;
-    CTYPE step =
-        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
-    thrust::transform(
-        cu::thrust_policy(encoder.stream()),
-        thrust::counting_iterator<uint32_t>(0),
-        thrust::counting_iterator<uint32_t>(out.data_size()),
-        thrust::device_pointer_cast(out.data<OutType>()),
-        cu::Arange<OutType>{
-            static_cast<OutType>(start_), static_cast<OutType>(step)});
-  });
-}
-
 bool fast::ScaledDotProductAttention::use_fallback(
     const array& q,
     const array& k,
```
```diff
@@ -350,12 +350,10 @@ void fast::AffineQuantize::eval_gpu(
     dispatch_bits(bits_, [&](auto bits) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       if (dequantize_) {
-        auto kernel =
-            cu::affine_dequantize<DataType, group_size.value, bits.value>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, size, grid_shape, w.strides(), large);
+            get_launch_args(size, grid_shape, w.strides(), large);
         enc.add_kernel_node(
-            kernel,
+            cu::affine_dequantize<DataType, group_size.value, bits.value>,
             num_blocks,
             block_dims,
             w.data<uint8_t>(),
@@ -364,12 +362,10 @@ void fast::AffineQuantize::eval_gpu(
         out.data<DataType>(),
         out.size());
       } else {
-        auto kernel =
-            cu::affine_quantize<DataType, group_size.value, bits.value>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, size, grid_shape, w.strides(), large);
+            get_launch_args(size, grid_shape, w.strides(), large);
         enc.add_kernel_node(
-            kernel,
+            cu::affine_quantize<DataType, group_size.value, bits.value>,
             num_blocks,
             block_dims,
             w.data<DataType>(),
```
```diff
@@ -11,7 +11,6 @@
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 #include <nvtx3/nvtx3.hpp>
-#include <cub/block/block_load.cuh>
 
 #include <cassert>
 
@@ -45,20 +44,21 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   AccT maxval = Limits<AccT>::finite_min();
   AccT normalizer = cast_to<AccT>(0);
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
-    AccT vals[N_READS];
-    cub::LoadDirectBlocked(
-        r * BLOCK_DIM + block.thread_rank(),
-        make_cast_iterator<AccT>(in),
-        vals,
-        axis_size,
-        Limits<AccT>::min());
+    auto index = r * BLOCK_DIM + block.thread_rank();
+    auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
     prevmax = maxval;
-    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      maxval = max_op(maxval, static_cast<AccT>(vals[i]));
+    }
+
     // Online normalizer calculation for softmax:
     // https://github.com/NVIDIA/online-softmax
     normalizer = normalizer * softmax_exp(prevmax - maxval);
+#pragma unroll
     for (int i = 0; i < N_READS; i++) {
-      normalizer = normalizer + softmax_exp(vals[i] - maxval);
+      normalizer =
+          normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
     }
   }
 
@@ -95,12 +95,11 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   // Write output.
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
     auto index = r * BLOCK_DIM + block.thread_rank();
-    T vals[N_READS];
-    cub::LoadDirectBlocked(index, in, vals, axis_size);
+    auto vals = load_vector<N_READS>(in, index, axis_size, T(0));
     for (int i = 0; i < N_READS; i++) {
       vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
     }
-    cub::StoreDirectBlocked(index, out, vals, axis_size);
+    store_vector<N_READS>(out, index, vals, axis_size);
   }
 }
 
@@ -141,9 +140,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
-    constexpr int N_READS = 4;
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    constexpr int N_READS = 16 / sizeof(DataType);
+    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
       auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
       if (precise) {
         kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
```
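Both the softmax and logsumexp dispatches now derive the per-thread read count from the dtype, `N_READS = 16 / sizeof(DataType)`, so each vectorized `load_vector` moves 16 bytes regardless of precision (4 floats, or 8 two-byte half/bfloat values), where the old code always read 4 elements. The arithmetic, checked on the host (`short` stands in for a 2-byte half type):

```cpp
#include <cstdio>

template <typename T>
constexpr int n_reads() {
  return static_cast<int>(16 / sizeof(T));
}

int main() {
  static_assert(n_reads<float>() == 4, "four 4-byte elements per 16-byte load");
  static_assert(n_reads<short>() == 8, "eight 2-byte elements per 16-byte load");
  std::printf("float: %d, 2-byte: %d\n", n_reads<float>(), n_reads<short>());
  return 0;
}
```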
```diff
@@ -125,12 +125,9 @@ void ternary_op_gpu_inplace(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel =
-            cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+        auto [num_blocks, block_dims] = get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
             num_blocks,
             block_dims,
             a.data<bool>(),
@@ -144,11 +141,9 @@ void ternary_op_gpu_inplace(
             const_param<dims_constant()>(c_strides));
       });
     } else {
-      auto kernel = cu::ternary_g<Op, DType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::ternary_g<Op, DType, IdxT>,
           num_blocks,
           block_dims,
           a.data<bool>(),
@@ -167,16 +162,10 @@ void ternary_op_gpu_inplace(
   dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
     using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
     constexpr int N_READS = 16 / sizeof(DType);
-    auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
     auto [num_blocks, block_dims] = get_launch_args(
-        kernel,
-        out.data_size(),
-        out.shape(),
-        out.strides(),
-        large(),
-        N_READS);
+        out.data_size(), out.shape(), out.strides(), large(), N_READS);
     encoder.add_kernel_node(
-        kernel,
+        cu::ternary_v<Op, DType, IdxT, N_READS>,
         num_blocks,
         block_dims,
         a.data<bool>(),
```
```diff
@@ -129,16 +129,10 @@ void unary_op_gpu_inplace(
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
       // TODO: Choose optimized value based on type size.
       constexpr int N_READS = 4;
-      auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large,
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large, N_READS);
       encoder.add_kernel_node(
-          kernel,
+          cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
           num_blocks,
           block_dims,
           in.data<InType>(),
@@ -147,10 +141,9 @@ void unary_op_gpu_inplace(
     } else {
       using IdxT = std::conditional_t<large(), int64_t, int32_t>;
       auto [shape, strides] = collapse_contiguous_dims(in);
-      auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+      auto [num_blocks, block_dims] = get_launch_args(out, large);
       encoder.add_kernel_node(
-          kernel,
+          cu::unary_g<Op, InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in.data<InType>(),
```