This commit is contained in:
Awni Hannun 2025-06-16 03:44:55 +00:00 committed by GitHub
commit 61e5d59bf9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 60 additions and 36 deletions

View File

@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@ -197,7 +197,7 @@ void binary_op_gpu_inplace(
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, LARGE);
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
@ -264,7 +264,6 @@ BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
@ -279,6 +278,17 @@ BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();

View File

@ -6,7 +6,7 @@
namespace mlx::core {
void copy_gpu_inplace(
const array& in_,
const array& in,
array& out,
const Shape& shape,
const Strides& strides_in,
@ -20,7 +20,6 @@ void copy_gpu_inplace(
if (out.size() == 0) {
return;
}
const array& in = in_.data_shared_ptr() ? in_ : out;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);

View File

@ -15,14 +15,7 @@ namespace mlx::core {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
__VA_ARGS__; \
} else { \
throw std::runtime_error(fmt::format( \
"Can not copy data from dtype {} to {}.", \
dtype_to_string(out.dtype()), \
dtype_to_string(in.dtype()))); \
} \
__VA_ARGS__; \
}); \
})

View File

@ -43,7 +43,8 @@ void copy_contiguous(
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] =
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,

View File

@ -59,9 +59,9 @@ void copy_general(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -70,7 +70,7 @@ void copy_general(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
@ -81,7 +81,7 @@ void copy_general(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@ -65,9 +65,9 @@ void copy_general_dynamic(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -76,7 +76,7 @@ void copy_general_dynamic(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
@ -89,7 +89,7 @@ void copy_general_dynamic(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@ -54,9 +54,9 @@ void copy_general_input(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -65,7 +65,7 @@ void copy_general_input(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
@ -75,7 +75,7 @@ void copy_general_input(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
ndim);

View File

@ -45,6 +45,19 @@ struct CastOp<
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
static constexpr bool is_castable = true;
__device__ SrcT operator()(SrcT x) {
return x;
}
};
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {

View File

@ -136,17 +136,19 @@ inline uint max_occupancy_block_dim(T kernel) {
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
@ -154,4 +156,13 @@ inline std::tuple<dim3, uint> get_launch_args(
return std::make_tuple(num_blocks, block_dim);
}
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core

View File

@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides),
@ -142,7 +142,7 @@ void ternary_op_gpu_inplace(
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),

View File

@ -91,7 +91,7 @@ void unary_op_gpu_inplace(
} else {
auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
in_ptr, in.data_size(), shape, strides);
in_ptr, in.size(), shape, strides);
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
}
} else {

View File

@ -100,7 +100,6 @@ cuda_skip = {
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
"TestOps.test_irregular_binary_ops",
"TestOps.test_isfinite",
"TestOps.test_kron",
"TestOps.test_log",
"TestOps.test_log10",
@ -115,7 +114,6 @@ cuda_skip = {
"TestOps.test_sort",
"TestOps.test_tensordot",
"TestOps.test_tile",
"TestOps.test_view",
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
@ -136,7 +134,6 @@ cuda_skip = {
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
"TestVmap.test_unary",
"TestVmap.test_vmap_conv",
"TestVmap.test_vmap_inverse",
"TestVmap.test_vmap_svd",