mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
[CUDA] ternary with select op (#2283)
* cuda ternary with select op * comment + fix * fix
This commit is contained in:
parent
aa07429bad
commit
2188199ff8
@ -34,6 +34,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
|
12
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
12
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
struct Select {
|
||||||
|
template <typename T>
|
||||||
|
__device__ T operator()(bool condition, T x, T y) {
|
||||||
|
return condition ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -162,6 +162,27 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int NDIM, typename IdxT = int64_t>
|
||||||
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||||
|
IdxT elem,
|
||||||
|
const int* shape,
|
||||||
|
const int64_t* a_strides,
|
||||||
|
const int64_t* b_strides,
|
||||||
|
const int64_t* c_strides) {
|
||||||
|
IdxT a_loc = 0;
|
||||||
|
IdxT b_loc = 0;
|
||||||
|
IdxT c_loc = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int dim_idx = elem % shape[i];
|
||||||
|
a_loc += dim_idx * a_strides[i];
|
||||||
|
b_loc += dim_idx * b_strides[i];
|
||||||
|
c_loc += dim_idx * c_strides[i];
|
||||||
|
elem /= shape[i];
|
||||||
|
}
|
||||||
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
}
|
||||||
|
|
||||||
// Optimized version when ndim is larger than 4.
|
// Optimized version when ndim is larger than 4.
|
||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ IdxT
|
inline __host__ __device__ IdxT
|
||||||
@ -191,6 +212,26 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename IdxT = int64_t>
|
||||||
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||||
|
IdxT elem,
|
||||||
|
const int* shape,
|
||||||
|
const int64_t* a_strides,
|
||||||
|
const int64_t* b_strides,
|
||||||
|
const int64_t* c_strides,
|
||||||
|
int ndim) {
|
||||||
|
auto [a_loc, b_loc, c_loc] =
|
||||||
|
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
||||||
|
for (int i = ndim - 1; i >= 3; --i) {
|
||||||
|
int dim_idx = elem % shape[i];
|
||||||
|
a_loc += dim_idx * a_strides[i];
|
||||||
|
b_loc += dim_idx * b_strides[i];
|
||||||
|
c_loc += dim_idx * c_strides[i];
|
||||||
|
elem /= shape[i];
|
||||||
|
}
|
||||||
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Elem to loc in a loop utils
|
// Elem to loc in a loop utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -91,7 +91,6 @@ NO_GPU(QuantizedMatmul)
|
|||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
NO_GPU(Scatter)
|
NO_GPU(Scatter)
|
||||||
NO_GPU(ScatterAxis)
|
NO_GPU(ScatterAxis)
|
||||||
NO_GPU(Select)
|
|
||||||
NO_GPU_MULTI(SVD)
|
NO_GPU_MULTI(SVD)
|
||||||
NO_GPU(Inverse)
|
NO_GPU(Inverse)
|
||||||
NO_GPU(Cholesky)
|
NO_GPU(Cholesky)
|
||||||
|
177
mlx/backend/cuda/ternary.cu
Normal file
177
mlx/backend/cuda/ternary.cu
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
#include "mlx/backend/common/ternary.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename Op, typename T, typename IdxT>
|
||||||
|
__global__ void
|
||||||
|
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = Op{}(a[index], b[index], c[index]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename T, typename IdxT, int NDIM>
|
||||||
|
__global__ void ternary_g_nd(
|
||||||
|
const bool* a,
|
||||||
|
const T* b,
|
||||||
|
const T* c,
|
||||||
|
T* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index,
|
||||||
|
shape.data(),
|
||||||
|
a_strides.data(),
|
||||||
|
b_strides.data(),
|
||||||
|
c_strides.data());
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename T, typename IdxT>
|
||||||
|
__global__ void ternary_g(
|
||||||
|
const bool* a,
|
||||||
|
const T* b,
|
||||||
|
const T* c,
|
||||||
|
T* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ Shape shape,
|
||||||
|
const __grid_constant__ Strides a_strides,
|
||||||
|
const __grid_constant__ Strides b_strides,
|
||||||
|
const __grid_constant__ Strides c_strides,
|
||||||
|
int ndim) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
|
||||||
|
index,
|
||||||
|
shape.data(),
|
||||||
|
a_strides.data(),
|
||||||
|
b_strides.data(),
|
||||||
|
c_strides.data(),
|
||||||
|
ndim);
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void ternary_op_gpu_inplace(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const Stream& s) {
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
const auto& b = inputs[1];
|
||||||
|
const auto& c = inputs[2];
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
|
||||||
|
using DType = cuda_type_t<CTYPE>;
|
||||||
|
|
||||||
|
auto topt = get_ternary_op_type(a, b, c);
|
||||||
|
if (topt == TernaryOpType::General) {
|
||||||
|
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||||
|
auto& a_strides = strides[0];
|
||||||
|
auto& b_strides = strides[1];
|
||||||
|
auto& c_strides = strides[2];
|
||||||
|
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||||
|
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||||
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
|
int ndim = shape.size();
|
||||||
|
if (ndim <= 3) {
|
||||||
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
|
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, out, large);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<bool>(),
|
||||||
|
b.data<DType>(),
|
||||||
|
c.data<DType>(),
|
||||||
|
out.data<DType>(),
|
||||||
|
out.data_size(),
|
||||||
|
const_param<NDIM>(shape),
|
||||||
|
const_param<NDIM>(a_strides),
|
||||||
|
const_param<NDIM>(b_strides),
|
||||||
|
const_param<NDIM>(c_strides));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<bool>(),
|
||||||
|
b.data<DType>(),
|
||||||
|
c.data<DType>(),
|
||||||
|
out.data<DType>(),
|
||||||
|
out.data_size(),
|
||||||
|
const_param(shape),
|
||||||
|
const_param(a_strides),
|
||||||
|
const_param(b_strides),
|
||||||
|
const_param(c_strides),
|
||||||
|
ndim);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||||
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
|
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<bool>(),
|
||||||
|
b.data<DType>(),
|
||||||
|
c.data<DType>(),
|
||||||
|
out.data<DType>(),
|
||||||
|
out.data_size());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void ternary_op_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto& c = inputs[2];
|
||||||
|
auto topt = get_ternary_op_type(a, b, c);
|
||||||
|
set_ternary_op_output_data(a, b, c, out, topt);
|
||||||
|
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("select::eval_gpu");
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user