Compare commits

...

12 Commits

Author SHA1 Message Date
acsweet
e2a2bae148
Merge 992eac905a into b8022c578a 2025-06-17 20:34:16 +02:00
Awni Hannun
b8022c578a
divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f
Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451
[CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9
add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b
RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d
Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f
fix cuda gemm for bf16 (#2288) 2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7
fix cuda jit (#2287) 2025-06-13 19:21:46 -07:00
acsweet
992eac905a
Merge branch 'main' into metal-thread-safe 2025-05-27 09:40:36 -07:00
Andrew Sweet
c8d4d97447 tests added 2025-05-07 11:59:54 -07:00
Andrew Sweet
28902ece4e updated, simplified mutex for thread safety 2025-04-30 16:17:12 -07:00
76 changed files with 1396 additions and 255 deletions

View File

@ -234,6 +234,7 @@ jobs:
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
build_release:
parameters:

View File

@ -107,6 +107,16 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as
expected. For example:

View File

@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core

View File

@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();

View File

@ -8,6 +8,7 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
@ -32,6 +33,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu

View File

@ -1,5 +1,4 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().z;
auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{1, 1, BLOCK_DIM};
dim3 block_dims{BLOCK_DIM, 1, 1};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,

View File

@ -101,10 +101,12 @@ constexpr bool supports_binary_op() {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
return std::is_same_v<Out, bool> && is_inexact_v<In>;
}
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
if (std::is_same_v<Op, LogAddExp>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
@ -123,13 +125,12 @@ constexpr bool supports_binary_op() {
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
array& out,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
@ -144,16 +145,15 @@ void binary_op_gpu_inplace(
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > UINT32_MAX ||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
@ -217,20 +217,6 @@ void binary_op_gpu_inplace(
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
@ -241,8 +227,7 @@ void binary_op_gpu(
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define BINARY_GPU(func) \
@ -252,19 +237,10 @@ void binary_op_gpu(
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
@ -279,6 +255,17 @@ BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();

View File

@ -0,0 +1,248 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[0]);
out_a[0] = out[0];
out_b[0] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[0]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out_a.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core

View File

@ -130,11 +130,13 @@ struct FusedKernelBuilder {
constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
void Compiled::eval_gpu(

View File

@ -6,7 +6,7 @@
namespace mlx::core {
void copy_gpu_inplace(
const array& in_,
const array& in,
array& out,
const Shape& shape,
const Strides& strides_in,
@ -20,7 +20,6 @@ void copy_gpu_inplace(
if (out.size() == 0) {
return;
}
const array& in = in_.data_shared_ptr() ? in_ : out;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);

View File

@ -10,20 +10,13 @@
namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
__VA_ARGS__; \
} else { \
throw std::runtime_error(fmt::format( \
"Can not copy data from dtype {} to {}.", \
dtype_to_string(out.dtype()), \
dtype_to_string(in.dtype()))); \
} \
}); \
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
__VA_ARGS__; \
}); \
})
void copy_contiguous(

View File

@ -43,7 +43,8 @@ void copy_contiguous(
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,

View File

@ -59,9 +59,9 @@ void copy_general(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -70,7 +70,7 @@ void copy_general(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
@ -81,7 +81,7 @@ void copy_general(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@ -65,9 +65,9 @@ void copy_general_dynamic(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -76,7 +76,7 @@ void copy_general_dynamic(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
@ -89,7 +89,7 @@ void copy_general_dynamic(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@ -54,9 +54,9 @@ void copy_general_input(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -65,7 +65,7 @@ void copy_general_input(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
@ -75,7 +75,7 @@ void copy_general_input(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
ndim);

View File

@ -1,6 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
@ -20,7 +22,7 @@ struct FloorDivide {
if constexpr (cuda::std::is_integral_v<T>) {
return x / y;
} else {
return trunc(x / y);
return truncf(x / y);
}
}
};
@ -122,6 +124,26 @@ struct LogAddExp {
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
};
struct Maximum {

View File

@ -45,6 +45,18 @@ struct CastOp<
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
static constexpr bool is_castable = true;
__device__ SrcT operator()(SrcT x) {
return x;
}
};
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {

View File

@ -5,7 +5,7 @@
#pragma once
// The maximum dimensions of shape/strides passed as kernel parameters.
#define MAX_NDIM 8
#define MAX_NDIM 10
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.

View File

@ -1,4 +1,5 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cu {

View File

@ -5,6 +5,8 @@
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <math_constants.h>
namespace mlx::core::cu {
struct Abs {
@ -183,21 +185,38 @@ struct Imag {
struct Log {
template <typename T>
__device__ T operator()(T x) {
return log(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto r = log(cuCrealf(Abs{}(x)));
auto i = atan2f(cuCimagf(x), cuCrealf(x));
return {r, i};
} else {
return log(x);
}
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
return log2(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
} else {
return log2(x);
}
}
};
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
return log10(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
return y;
} else {
return log10(x);
}
}
};

View File

@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
for (int i = ndim - 1; i >= 3; --i) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
for (int i = ndim - 1; i >= 3; --i) {
IdxT a_loc = 0;
IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
const int64_t* b_strides,
const int64_t* c_strides,
int ndim) {
auto [a_loc, b_loc, c_loc] =
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
for (int i = ndim - 1; i >= 3; --i) {
IdxT a_loc = 0;
IdxT b_loc = 0;
IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
}
};
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
} // namespace mlx::core::cu

View File

@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
uint32_t slice_size = std::accumulate(
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_cuda_type(idx_dtype),
nidx,
ndim,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_gather, std::move(kernel_names));
@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(out.size());
} else {
mod.append_arg<uint32_t>(out.size());
mod.append_arg<int32_t>(out.size());
}
mod.append_ndim_arg(src.shape());
mod.append_ndim_arg(src.strides());
@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_cuda_type(idx_dtype),
nidx,
idx_ndim,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
uint32_t upd_post_idx_size = std::accumulate(
int32_t upd_post_idx_size = std::accumulate(
upd.shape().begin() + idx_ndim,
upd.shape().end(),
1,
std::multiplies<uint32_t>());
std::multiplies<int32_t>());
const char* op = g_scatter_ops[reduce_type_];
std::string module_name = fmt::format(
@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op,
nidx,
ndim,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_scatter, std::move(kernel_names));
@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(upd.size());
} else {
mod.append_arg<uint32_t>(upd.size());
mod.append_arg<int32_t>(upd.size());
}
mod.append_ndim_arg(upd.shape());
mod.append_ndim_arg(upd.strides());
@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(upd_post_idx_size);
} else {
mod.append_arg<uint32_t>(upd_post_idx_size);
mod.append_arg<int32_t>(upd_post_idx_size);
}
mod.append_ndim_arg(out.shape());
mod.append_ndim_arg(out.strides());
@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op,
nidx,
idx_ndim,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
std::string module_name = fmt::format(
"gather_axis_{}_{}",
@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
ndim,
contiguous & 1 ? true : false,
contiguous & 2 ? true : false,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
}
@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
mod.append_arg<uint32_t>(idx_size_pre);
mod.append_arg<uint32_t>(idx_size_axis);
mod.append_arg<uint32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(src.strides(), axis_));
@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
src.ndim() - 1,
src.flags().row_contiguous,
idx.flags().row_contiguous,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
std::string module_name = fmt::format(
@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
ndim,
contiguous & 1 ? true : false,
contiguous & 2 ? true : false,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
}
@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
mod.append_arg<uint32_t>(idx_size_pre);
mod.append_arg<uint32_t>(idx_size_axis);
mod.append_arg<uint32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(upd.strides(), axis_));
@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.ndim() - 1,
upd.flags().row_contiguous,
idx.flags().row_contiguous,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {

View File

@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) {
}
}
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh",

View File

@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
auto [gx, gy, gz] = grid;
auto [bx, by, bz] = block;
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}
} // namespace mlx::core

View File

@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type traits for detecting complex or real floating point numbers.
template <typename T>
inline constexpr bool is_inexact_v =
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
@ -121,6 +126,7 @@ dim3 get_2d_grid_dims(
const Shape& shape,
const Strides& strides,
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
@ -135,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) {
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
@ -153,4 +161,14 @@ inline std::tuple<dim3, uint> get_launch_args(
return std::make_tuple(num_blocks, block_dim);
}
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
return get_launch_args(
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core

View File

@ -5,6 +5,7 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
@ -44,9 +45,12 @@ class MatMul {
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto type = dtype_to_cuda_type(dtype);
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@ -65,6 +69,7 @@ class MatMul {
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
@ -187,17 +192,13 @@ class MatMul {
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case uint8:
case uint16:
case int8:
case int16:
case int32:
return CUBLAS_COMPUTE_32I;
case float16:
case bfloat16:
return CUBLAS_COMPUTE_16F;
case float32:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
@ -209,16 +210,6 @@ class MatMul {
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case uint8:
return CUDA_R_8U;
case uint16:
return CUDA_R_16U;
case int8:
return CUDA_R_8I;
case int16:
return CUDA_R_16I;
case int32:
return CUDA_R_32I;
case float16:
return CUDA_R_16F;
case bfloat16:

View File

@ -71,10 +71,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(ArgPartition)
NO_GPU(BlockMaskedMM)
NO_GPU(Convolution)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
@ -83,7 +81,6 @@ NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
@ -94,7 +91,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)

View File

@ -4,6 +4,7 @@
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
@ -12,6 +13,8 @@ namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
__constant__ constexpr uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}};
@ -47,27 +50,28 @@ __global__ void rbitsc(
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
auto grid = cg::this_grid();
uint thread_index = grid.thread_rank();
uint index_x = thread_index % grid_dims.x;
uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto kidx = 2 * index_x;
auto key = uint2{keys[kidx], keys[kidx + 1]};
auto half_size = grid_dims.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
out += index_x * bytes_per_key;
bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
@ -89,30 +93,31 @@ __global__ void rbits(
int32_t ndim,
const __grid_constant__ Shape key_shape,
const __grid_constant__ Strides key_strides) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
auto grid = cg::this_grid();
uint thread_index = grid.thread_rank();
uint index_x = thread_index % grid_dims.x;
uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto kidx = 2 * index_x;
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
auto k2_elem =
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
auto key = uint2{keys[k1_elem], keys[k2_elem]};
auto half_size = grid_dims.y - odd;
out += size_t(index.x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
out += size_t(index_x) * bytes_per_key;
bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dim3 grid_dims{num_keys, half_size + odd};
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
dim3 num_blocks{
cuda::ceil_div(grid_dims.x, block_dims.x),
cuda::ceil_div(grid_dims.y, block_dims.y)};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
if (keys.flags().row_contiguous) {
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
cu::rbitsc<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key);
} else {
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
cu::rbits<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,

385
mlx/backend/cuda/rope.cu Normal file
View File

@ -0,0 +1,385 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl(
const T* in,
T* out,
int32_t offset,
float inv_freq,
float scale,
int64_t stride,
uint2 pos,
uint2 dims) {
float L = scale * static_cast<float>(offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
uint index_1, index_2;
if (traditional) {
index_1 = 2 * pos.x + pos.y * stride;
index_2 = index_1 + 1;
} else {
index_1 = pos.x + pos.y * stride;
index_2 = index_1 + dims.x;
}
// Read and write the output
float x1 = static_cast<float>(in[index_1]);
float x2 = static_cast<float>(in[index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[index_1] = static_cast<T>(rx1);
out[index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_single(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
int64_t stride,
uint2 dims) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_single_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
int64_t stride,
uint2 dims,
int64_t freq_stride) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
}
}
template <typename T, bool traditional, bool forward>
__global__ void rope(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
} // namespace cu
namespace fast {
bool RoPE::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("RoPE::eval_gpu");
auto& s = stream();
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
// We apply rope to less that the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
// Either copy or apply in-place
else if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
donated = true;
copy_gpu(in, out, CopyType::General, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
MLX_SWITCH_BOOL(forward_, FORWARD, {
if (single && !with_freqs) {
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
mat_size,
dims);
} else if (single) {
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
mat_size,
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims);
}
});
});
});
});
}
} // namespace fast
} // namespace mlx::core

View File

@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
axis += in.ndim();
}
int nsort = in.shape(axis);
int nsegments = in.data_size() / nsort;
int last_dim = in.ndim() - 1;
// If we are not sorting the innermost dimension of a contiguous array,
@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
encoder.launch_kernel([&](cudaStream_t stream) {
@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
nsegments,
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data<Type>(),
out.data<Type>(),
in.data_size(),
nsegments,
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
gpu_sort(stream(), inputs[0], out, axis_, false);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgPartition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, true);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Partition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, false);
}
} // namespace mlx::core

View File

@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
auto& c_strides = strides[2];
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides),
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),

View File

@ -27,12 +27,14 @@ constexpr bool supports_unary_op() {
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
@ -91,7 +93,7 @@ void unary_op_gpu_inplace(
} else {
auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
in_ptr, in.data_size(), shape, strides);
in_ptr, in.size(), shape, strides);
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
}
} else {

View File

@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
if (dtype == bfloat16) {
return "__nv_bfloat16";
}
if (dtype == complex64) {
return "cuComplex";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (dtype == DTYPE) { \
return #CPP_TYPE; \

View File

@ -4,12 +4,15 @@
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/thread_safey.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::gpu {
std::mutex metal_operation_mutex;
bool is_available() {
return true;
}
@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
}
void eval(array& arr) {
std::lock_guard<std::mutex> lock(metal_operation_mutex);
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
@ -78,6 +82,7 @@ void eval(array& arr) {
}
void finalize(Stream s) {
std::lock_guard<std::mutex> lock(metal_operation_mutex);
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
@ -88,6 +93,7 @@ void finalize(Stream s) {
}
void synchronize(Stream s) {
std::lock_guard<std::mutex> lock(metal_operation_mutex);
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);

View File

@ -2,6 +2,7 @@
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/thread_safey.h"
#include "mlx/scheduler.h"
namespace mlx::core {
@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
} else {
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);
auto command_buffer = d.get_command_buffer(stream.index);
@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
});
} else {
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);
auto command_buffer = d.get_command_buffer(stream.index);

View File

@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/fence.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/thread_safey.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
return;
}
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
auto& d = metal::device(stream.device);
auto idx = stream.index;
@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
return;
}
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
auto& d = metal::device(stream.device);
auto idx = stream.index;
if (!f.use_fast) {

View File

@ -0,0 +1,7 @@
#pragma once
#include <mutex>
namespace mlx::core::gpu {
extern std::mutex metal_operation_mutex;
}

View File

@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
return metal_fast_synch;
}
inline bool enable_tf32() {
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
return enable_tf32_;
}
} // namespace env
} // namespace mlx::core

5
python/tests/__main__.py Normal file
View File

@ -0,0 +1,5 @@
from . import mlx_tests
__unittest = True
mlx_tests.MLXTestRunner(module=None)

107
python/tests/cuda_skip.py Normal file
View File

@ -0,0 +1,107 @@
cuda_skip = {
"TestArray.test_api",
"TestBF16.test_arg_reduction_ops",
"TestBF16.test_reduction_ops",
"TestBlas.test_complex_gemm",
"TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases",
"TestLoad.test_load_f8_e4m3",
"TestMemory.test_memory_info",
"TestLayers.test_group_norm",
"TestLayers.test_pooling",
"TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe",
"TestLayers.test_upsample",
"TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing",
"TestOps.test_softmax",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
# Block masked matmul NYI
"TestBlas.test_block_masked_matmul",
# Gather matmul NYI
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
# Scan NYI
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",
# Hadamard NYI
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
# Convolutions NYI
"TestConv.test_1d_conv_with_2d",
"TestConv.test_asymmetric_padding",
"TestConv.test_basic_grad_shapes",
"TestConv.test_conv2d_unaligned_channels",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad",
"TestConv.test_numpy_conv",
"TestConv.test_repeated_conv",
"TestConv.test_torch_conv_1D",
"TestConv.test_torch_conv_1D_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_2D_grad",
"TestConv.test_torch_conv_3D",
"TestConv.test_torch_conv_3D_grad",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_1D",
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
"TestConvTranspose.test_torch_conv_transpose_2D",
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_3D",
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
"TestExportImport.test_export_conv",
"TestLayers.test_conv1d",
"TestLayers.test_conv2d",
"TestVmap.test_vmap_conv",
# FFTs NYI
"TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two",
"TestFFT.test_fft_contiguity",
"TestFFT.test_fft_exhaustive",
"TestFFT.test_fft_grads",
"TestFFT.test_fft_into_ifft",
"TestFFT.test_fft_large_numbers",
"TestFFT.test_fft_shared_mem",
"TestFFT.test_fftn",
# Lapack ops NYI
"TestLinalg.test_cholesky",
"TestLinalg.test_cholesky_inv",
"TestLinalg.test_eig",
"TestLinalg.test_eigh",
"TestLinalg.test_inverse",
"TestVmap.test_vmap_inverse",
"TestLinalg.test_lu",
"TestLinalg.test_lu_factor",
"TestLinalg.test_pseudo_inverse",
"TestLinalg.test_qr_factorization",
"TestInit.test_orthogonal",
"TestLinalg.test_svd_decomposition",
"TestVmap.test_vmap_svd",
"TestLinalg.test_tri_inverse",
# Quantization NYI
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
"TestQuantized.test_non_multiples",
"TestQuantized.test_qmm",
"TestQuantized.test_qmm_jvp",
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_qmm_vjp",
"TestQuantized.test_qmv",
"TestQuantized.test_quantize_dequantize",
"TestQuantized.test_qvm",
"TestQuantized.test_qvm_splitk",
"TestQuantized.test_small_matrix",
"TestQuantized.test_throw",
"TestQuantized.test_vjp_scales_biases",
}

View File

@ -9,6 +9,42 @@ import mlx.core as mx
import numpy as np
class MLXTestRunner(unittest.TestProgram):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def createTests(self, *args, **kwargs):
super().createTests(*args, **kwargs)
# Asume CUDA backend in this case
device = os.getenv("DEVICE", None)
if device is not None:
device = getattr(mx, device)
else:
device = mx.default_device()
if not (device == mx.gpu and not mx.metal.is_available()):
return
from cuda_skip import cuda_skip
filtered_suite = unittest.TestSuite()
def filter_and_add(t):
if isinstance(t, unittest.TestSuite):
for sub_t in t:
filter_and_add(sub_t)
else:
t_id = ".".join(t.id().split(".")[-2:])
if t_id in cuda_skip:
print(f"Skipping {t_id}")
else:
filtered_suite.addTest(t)
filter_and_add(self.test)
self.test = filtered_suite
class MLXTestCase(unittest.TestCase):
@property
def is_apple_silicon(self):

View File

@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
)
# Multiple slices
@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_multistream_deadlock(self):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64
Nq = 4
@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase):
logits, targets, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.binary_cross_entropy(
logits, targets, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
# With weights, no label smoothing
weights = mx.array([1.0, 2.0, 1.0, 2.0])
@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqualArray(result, mx.array(expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_issubdtype(self):
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
@ -3127,4 +3091,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -9,7 +9,9 @@ FetchContent_MakeAvailable(doctest)
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
if(MLX_BUILD_METAL)
set(METAL_TEST_SOURCES gpu_tests.cpp metal_thread_safety_tests.cpp)
elseif(MLX_BUILD_CUDA)
set(METAL_TEST_SOURCES gpu_tests.cpp)
endif()

View File

@ -589,6 +589,7 @@ TEST_CASE("test array shared buffer") {
array b = array(buf_b, shape, float32, deleter);
eval(a + b);
synchronize(); // ensure all operations complete before test ends
}
TEST_CASE("test make empty array") {

View File

@ -0,0 +1,250 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
#include "mlx/backend/metal/device.h"
#include <thread>
#include <vector>
#include <atomic>
#include <chrono>
#include <mutex>
#include <iostream>
using namespace mlx::core;
// Helper function to run operations across multiple threads with pre-created streams
void run_in_threads(int num_threads, const std::function<void(int, Stream)>& func,
const std::vector<Stream>& streams) {
std::vector<std::thread> threads;
threads.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back(func, i, streams[i % streams.size()]);
}
for (auto& t : threads) {
if (t.joinable()) {
t.join();
}
}
}
// Helper function for tasks not requiring streams (e.g., using default stream)
void run_in_threads_default(int num_threads, const std::function<void(int)>& func) {
std::vector<std::thread> threads;
threads.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back(func, i);
}
for (auto& t : threads) {
if (t.joinable()) {
t.join();
}
}
}
// Thread-safe result collection
struct TestResults {
std::mutex mutex;
std::vector<bool> shape_checks;
std::vector<bool> availability_checks;
std::vector<bool> value_checks;
std::vector<float> expected_values;
std::vector<float> actual_values;
void record_result(bool shape_ok, bool available_ok, bool value_ok,
float expected, float actual) {
std::lock_guard<std::mutex> lock(mutex);
shape_checks.push_back(shape_ok);
availability_checks.push_back(available_ok);
value_checks.push_back(value_ok);
expected_values.push_back(expected);
actual_values.push_back(actual);
}
};
TEST_CASE("test metal concurrent eval operations") {
Device D_GPU = Device::gpu;
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
const int ops_per_thread = 10;
const int array_size = 32;
std::atomic<int> completed_ops{0};
TestResults results;
// Pre-create streams to avoid concurrent stream creation
std::vector<Stream> streams;
for (int i = 0; i < num_threads; ++i) {
streams.push_back(new_stream(D_GPU));
}
synchronize(); // Ensure stream creation is complete
auto task = [&](int thread_id, Stream s) {
try {
for (int i = 0; i < ops_per_thread; ++i) {
float val1 = static_cast<float>(thread_id * ops_per_thread + i + 1);
float val2 = val1 * 2.0f;
auto x = full({array_size, array_size}, val1, s);
auto y = full({array_size, array_size}, val2, s);
auto z = add(x, y);
eval(z);
bool shape_ok = (z.shape() == Shape{array_size, array_size});
bool available_ok = z.is_available();
// Get a value from the array
int mid = array_size/2;
auto sample = slice(z, {mid, mid}, {mid+1, mid+1});
float actual = sample.item<float>();
float expected = val1 + val2;
bool values_match = (std::abs(actual - expected) < 1e-5);
results.record_result(shape_ok, available_ok, values_match, expected, actual);
if (shape_ok && available_ok && values_match) {
completed_ops++;
}
}
} catch (const std::exception& e) {
std::cerr << "Thread " << thread_id << " exception: " << e.what() << std::endl;
}
};
// Run the threads with pre-created streams
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
// Check all results outside of threads
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
CAPTURE(i); // Help identify which operation failed
CHECK(results.shape_checks[i]);
CHECK(results.availability_checks[i]);
CHECK(results.value_checks[i]);
if (!results.value_checks[i]) {
CAPTURE(results.expected_values[i]);
CAPTURE(results.actual_values[i]);
}
}
// Verify all operations completed successfully
CHECK_EQ(completed_ops.load(), num_threads * ops_per_thread);
}
TEST_CASE("test metal high contention on default stream eval") {
Device D_GPU = Device::gpu;
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
const int ops_per_thread = 5;
const int array_size = 16;
Stream default_gpu_stream = default_stream(D_GPU);
std::atomic<int> successful_ops{0};
std::vector<std::string> thread_errors;
std::mutex errors_mutex;
TestResults results;
auto task = [&](int thread_id) {
try {
for (int i = 0; i < ops_per_thread; ++i) {
float val = static_cast<float>(thread_id * 100 + i + 1);
auto x = full({array_size, array_size}, val, default_gpu_stream);
auto y = full({array_size, array_size}, val * 0.5f, default_gpu_stream);
auto z = multiply(x, y);
eval(z);
// Sample a value
auto sample = slice(z, {0, 0}, {1, 1});
float actual = sample.item<float>();
float expected = val * val * 0.5f;
bool shape_ok = (z.shape() == Shape{array_size, array_size});
bool available_ok = z.is_available();
bool values_match = (std::abs(actual - expected) < 1e-5);
results.record_result(shape_ok, available_ok, values_match, expected, actual);
if (shape_ok && available_ok && values_match) {
successful_ops++;
}
}
} catch (const std::exception& e) {
std::lock_guard<std::mutex> lock(errors_mutex);
thread_errors.push_back(std::string("Thread ") +
std::to_string(thread_id) +
" exception: " + e.what());
}
};
// Use the default helper for this test since it uses the default stream
CHECK_NOTHROW(run_in_threads_default(num_threads, task));
// Check for thread errors
CHECK(thread_errors.empty());
if (!thread_errors.empty()) {
for (const auto& err : thread_errors) {
CAPTURE(err);
}
}
// Check all results
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
CAPTURE(i);
CHECK(results.shape_checks[i]);
CHECK(results.availability_checks[i]);
CHECK(results.value_checks[i]);
if (!results.value_checks[i]) {
CAPTURE(results.expected_values[i]);
CAPTURE(results.actual_values[i]);
}
}
// Verify operation count
CHECK_EQ(successful_ops.load(), num_threads * ops_per_thread);
}
TEST_CASE("test metal concurrent graph eval from different threads") {
Device D_GPU = Device::gpu;
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 4; // Keep modest for clarity
const int array_size = 64;
TestResults all_results;
// Pre-create streams
std::vector<Stream> streams;
for (int i = 0; i < num_threads; ++i) {
streams.push_back(new_stream(D_GPU));
}
synchronize();
auto task = [&](int thread_id, Stream s) {
try {
float val1_base = static_cast<float>(thread_id + 1) * 10.0f;
auto x = full({array_size, array_size}, val1_base, s);
auto y = full({array_size, array_size}, val1_base + 1.0f, s);
auto z = add(x, y);
auto w = multiply(z, x);
eval(w);
float expected_val = (val1_base + (val1_base + 1.0f)) * val1_base;
auto sample = slice(w, {0,0}, {1,1});
float actual_val = sample.item<float>();
bool shape_ok = (w.shape() == Shape{array_size, array_size});
bool available_ok = w.is_available();
bool value_ok = (std::abs(actual_val - expected_val) < 1e-4);
all_results.record_result(shape_ok, available_ok, value_ok, expected_val, actual_val);
} catch (const std::exception& e) {
std::cerr << "Thread " << thread_id << " exception in concurrent graph eval: " << e.what() << std::endl;
}
};
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
CHECK_EQ(all_results.shape_checks.size(), num_threads); // One result per thread
for (size_t i = 0; i < num_threads; ++i) {
CAPTURE(i);
CHECK(all_results.shape_checks[i]);
CHECK(all_results.availability_checks[i]);
CHECK(all_results.value_checks[i]);
if (!all_results.value_checks[i]) {
CAPTURE(all_results.expected_values[i]);
CAPTURE(all_results.actual_values[i]);
}
}
}