mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
12 Commits
93839bf44d
...
e2a2bae148
Author | SHA1 | Date | |
---|---|---|---|
![]() |
e2a2bae148 | ||
![]() |
b8022c578a | ||
![]() |
bc53f8293f | ||
![]() |
c552ff2451 | ||
![]() |
4fda5fbdf9 | ||
![]() |
580776559b | ||
![]() |
a14aaa7c9d | ||
![]() |
a6d780154f | ||
![]() |
6871e2eeb7 | ||
![]() |
992eac905a | ||
![]() |
c8d4d97447 | ||
![]() |
28902ece4e |
@ -234,6 +234,7 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
|
@ -107,6 +107,16 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
|
@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
||||
auto gx = (dim0 + bx - 1) / bx;
|
||||
auto gy = (dim1 + by - 1) / by;
|
||||
auto gz = (dim2 + bz - 1) / bz;
|
||||
|
||||
return std::make_pair(
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
||||
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
|
@ -8,6 +8,7 @@ target_sources(
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
@ -32,6 +33,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T vals[N_READS];
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().z;
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||
cub::LoadDirectBlocked(
|
||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{1, 1, BLOCK_DIM};
|
||||
dim3 block_dims{BLOCK_DIM, 1, 1};
|
||||
auto kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMax<InType>,
|
||||
|
@ -101,10 +101,12 @@ constexpr bool supports_binary_op() {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
|
||||
if (std::is_same_v<Op, LogAddExp>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
@ -123,13 +125,12 @@ constexpr bool supports_binary_op() {
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -144,16 +145,15 @@ void binary_op_gpu_inplace(
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > UINT32_MAX ||
|
||||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
@ -217,20 +217,6 @@ void binary_op_gpu_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
@ -241,8 +227,7 @@ void binary_op_gpu(
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
std::vector<array> outputs{out};
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
@ -252,19 +237,10 @@ void binary_op_gpu(
|
||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = outputs[0].primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
@ -279,6 +255,17 @@ BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
auto op = get_primitive_string(this);
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
|
248
mlx/backend/cuda/binary_two.cu
Normal file
248
mlx/backend/cuda/binary_two.cu
Normal file
@ -0,0 +1,248 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[0]);
|
||||
out_a[0] = out[0];
|
||||
out_b[0] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[0], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[0]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void
|
||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto out = Op{}(a[index], b[index]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_4d(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, DivMod>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
if (out_a.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
|
||||
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel,
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -130,11 +130,13 @@ struct FusedKernelBuilder {
|
||||
|
||||
constexpr const char* g_jit_includes = R"(
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
)";
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
|
@ -6,7 +6,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in_,
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
@ -20,7 +20,6 @@ void copy_gpu_inplace(
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
const array& in = in_.data_shared_ptr() ? in_ : out;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
|
@ -10,20 +10,13 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::runtime_error(fmt::format( \
|
||||
"Can not copy data from dtype {} to {}.", \
|
||||
dtype_to_string(out.dtype()), \
|
||||
dtype_to_string(in.dtype()))); \
|
||||
} \
|
||||
}); \
|
||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||
using InType = cuda_type_t<CTYPE_IN>; \
|
||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||
__VA_ARGS__; \
|
||||
}); \
|
||||
})
|
||||
|
||||
void copy_contiguous(
|
||||
|
@ -43,7 +43,8 @@ void copy_contiguous(
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
|
@ -59,9 +59,9 @@ void copy_general(
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -70,7 +70,7 @@ void copy_general(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out));
|
||||
@ -81,7 +81,7 @@ void copy_general(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
|
@ -65,9 +65,9 @@ void copy_general_dynamic(
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -76,7 +76,7 @@ void copy_general_dynamic(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in),
|
||||
const_param<NDIM>(strides_out),
|
||||
@ -89,7 +89,7 @@ void copy_general_dynamic(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
|
@ -54,9 +54,9 @@ void copy_general_input(
|
||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -65,7 +65,7 @@ void copy_general_input(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(strides_in));
|
||||
});
|
||||
@ -75,7 +75,7 @@ void copy_general_input(
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
|
@ -1,6 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
@ -20,7 +22,7 @@ struct FloorDivide {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return trunc(x / y);
|
||||
return truncf(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -122,6 +124,26 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
};
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
return maxval;
|
||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
||||
cuComplex dexp{
|
||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@ -45,6 +45,18 @@ struct CastOp<
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ SrcT operator()(SrcT x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
|
@ -5,7 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 8
|
||||
#define MAX_NDIM 10
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
|
@ -1,4 +1,5 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <math_constants.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
@ -183,21 +185,38 @@ struct Imag {
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto r = log(cuCrealf(Abs{}(x)));
|
||||
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||
return {r, i};
|
||||
} else {
|
||||
return log(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log2(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log10(x);
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto y = Log{}(x);
|
||||
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||
return y;
|
||||
} else {
|
||||
return log10(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc, c_loc] =
|
||||
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ cuComplex log1p(cuComplex in) {
|
||||
float x = cuCrealf(in);
|
||||
float y = cuCimagf(in);
|
||||
float zabs = sqrt(x * x + y * y);
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t slice_size = std::accumulate(
|
||||
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||
@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(out.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(out.size());
|
||||
mod.append_arg<int32_t>(out.size());
|
||||
}
|
||||
mod.append_ndim_arg(src.shape());
|
||||
mod.append_ndim_arg(src.strides());
|
||||
@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t upd_post_idx_size = std::accumulate(
|
||||
int32_t upd_post_idx_size = std::accumulate(
|
||||
upd.shape().begin() + idx_ndim,
|
||||
upd.shape().end(),
|
||||
1,
|
||||
std::multiplies<uint32_t>());
|
||||
std::multiplies<int32_t>());
|
||||
|
||||
const char* op = g_scatter_ops[reduce_type_];
|
||||
std::string module_name = fmt::format(
|
||||
@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op,
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd.size());
|
||||
mod.append_arg<int32_t>(upd.size());
|
||||
}
|
||||
mod.append_ndim_arg(upd.shape());
|
||||
mod.append_ndim_arg(upd.strides());
|
||||
@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd_post_idx_size);
|
||||
mod.append_arg<int32_t>(upd_post_idx_size);
|
||||
}
|
||||
mod.append_ndim_arg(out.shape());
|
||||
mod.append_ndim_arg(out.strides());
|
||||
@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_axis_{}_{}",
|
||||
@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(src.strides(), axis_));
|
||||
@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
src.ndim() - 1,
|
||||
src.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||
std::string module_name = fmt::format(
|
||||
@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||
@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx.ndim() - 1,
|
||||
upd.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
|
@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) {
|
||||
}
|
||||
}
|
||||
|
||||
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
|
||||
#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
|
||||
|
||||
constexpr const char* g_include_names[] = {
|
||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||
|
@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
|
||||
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||
}
|
||||
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
|
||||
auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
|
||||
auto [gx, gy, gz] = grid;
|
||||
auto [bx, by, bz] = block;
|
||||
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Type traits for detecting complex or real floating point numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_inexact_v =
|
||||
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
@ -121,6 +126,7 @@ dim3 get_2d_grid_dims(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
|
||||
// Return a block size that achieves maximum potential occupancy for kernel.
|
||||
template <typename T>
|
||||
@ -135,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) {
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
const array& arr,
|
||||
size_t size,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||
uint block_dim = max_occupancy_block_dim(kernel);
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
dim3 num_blocks;
|
||||
if (large) {
|
||||
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
|
||||
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
|
||||
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
||||
} else {
|
||||
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
||||
@ -153,4 +161,14 @@ inline std::tuple<dim3, uint> get_launch_args(
|
||||
return std::make_tuple(num_blocks, block_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
T kernel,
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
return get_launch_args(
|
||||
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <fmt/format.h>
|
||||
@ -44,9 +45,12 @@ class MatMul {
|
||||
int64_t b_batch_stride) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
auto scale_type = dtype_to_cuda_type(dtype);
|
||||
if (dtype == bfloat16 || dtype == float16) {
|
||||
scale_type = CUDA_R_32F;
|
||||
}
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), type));
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
|
||||
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
@ -65,6 +69,7 @@ class MatMul {
|
||||
&op,
|
||||
sizeof(cublasOperation_t)));
|
||||
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
a_desc_ = create_matrix_layout(
|
||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||
b_desc_ = create_matrix_layout(
|
||||
@ -187,17 +192,13 @@ class MatMul {
|
||||
private:
|
||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case uint8:
|
||||
case uint16:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
return CUBLAS_COMPUTE_32I;
|
||||
case float16:
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
case float32:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case bfloat16:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case float32:
|
||||
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
: CUBLAS_COMPUTE_32F;
|
||||
case float64:
|
||||
case complex64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
@ -209,16 +210,6 @@ class MatMul {
|
||||
|
||||
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case uint8:
|
||||
return CUDA_R_8U;
|
||||
case uint16:
|
||||
return CUDA_R_16U;
|
||||
case int8:
|
||||
return CUDA_R_8I;
|
||||
case int16:
|
||||
return CUDA_R_16I;
|
||||
case int32:
|
||||
return CUDA_R_32I;
|
||||
case float16:
|
||||
return CUDA_R_16F;
|
||||
case bfloat16:
|
||||
|
@ -71,10 +71,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
@ -83,7 +81,6 @@ NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Scan)
|
||||
@ -94,7 +91,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
@ -12,6 +13,8 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__constant__ constexpr uint32_t rotations[2][4] = {
|
||||
{13, 15, 26, 6},
|
||||
{17, 29, 16, 24}};
|
||||
@ -47,27 +50,28 @@ __global__ void rbitsc(
|
||||
dim3 grid_dims,
|
||||
bool odd,
|
||||
uint32_t bytes_per_key) {
|
||||
uint2 index{
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y};
|
||||
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
|
||||
auto grid = cg::this_grid();
|
||||
uint thread_index = grid.thread_rank();
|
||||
uint index_x = thread_index % grid_dims.x;
|
||||
uint index_y = thread_index / grid_dims.x;
|
||||
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kidx = 2 * index.x;
|
||||
auto kidx = 2 * index_x;
|
||||
auto key = uint2{keys[kidx], keys[kidx + 1]};
|
||||
auto half_size = grid_dims.y - odd;
|
||||
out += index.x * bytes_per_key;
|
||||
bool drop_last = odd && (index.y == half_size);
|
||||
out += index_x * bytes_per_key;
|
||||
bool drop_last = odd && (index_y == half_size);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
|
||||
size_t idx = size_t(index.y) << 2;
|
||||
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
|
||||
size_t idx = size_t(index_y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
|
||||
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
@ -89,30 +93,31 @@ __global__ void rbits(
|
||||
int32_t ndim,
|
||||
const __grid_constant__ Shape key_shape,
|
||||
const __grid_constant__ Strides key_strides) {
|
||||
uint2 index{
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y};
|
||||
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
|
||||
auto grid = cg::this_grid();
|
||||
uint thread_index = grid.thread_rank();
|
||||
uint index_x = thread_index % grid_dims.x;
|
||||
uint index_y = thread_index / grid_dims.x;
|
||||
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kidx = 2 * index.x;
|
||||
auto kidx = 2 * index_x;
|
||||
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
|
||||
auto k2_elem =
|
||||
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
|
||||
auto key = uint2{keys[k1_elem], keys[k2_elem]};
|
||||
auto half_size = grid_dims.y - odd;
|
||||
out += size_t(index.x) * bytes_per_key;
|
||||
bool drop_last = odd && (index.y == half_size);
|
||||
out += size_t(index_x) * bytes_per_key;
|
||||
bool drop_last = odd && (index_y == half_size);
|
||||
auto bits = threefry2x32_hash(
|
||||
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
|
||||
size_t idx = size_t(index.y) << 2;
|
||||
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
|
||||
size_t idx = size_t(index_y) << 2;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
out[idx + i] = bits.bytes[0][i];
|
||||
}
|
||||
if (!drop_last) {
|
||||
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
|
||||
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
|
||||
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
|
||||
int edge_bytes = (bytes_per_key % 4);
|
||||
for (int i = 0; i < edge_bytes; ++i) {
|
||||
out[idx + i] = bits.bytes[1][i];
|
||||
@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
|
||||
dim3 num_blocks{
|
||||
cuda::ceil_div(grid_dims.x, block_dims.x),
|
||||
cuda::ceil_div(grid_dims.y, block_dims.y)};
|
||||
int64_t total = grid_dims.x * grid_dims.y;
|
||||
int32_t threads_y = 1;
|
||||
while ((total / threads_y) >= (1U << 31)) {
|
||||
threads_y *= 2;
|
||||
}
|
||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||
if (keys.flags().row_contiguous) {
|
||||
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
|
||||
cu::rbitsc<<<grid, block, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
odd,
|
||||
bytes_per_key);
|
||||
} else {
|
||||
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
|
||||
cu::rbits<<<grid, block, 0, stream>>>(
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
|
385
mlx/backend/cuda/rope.cu
Normal file
385
mlx/backend/cuda/rope.cu
Normal file
@ -0,0 +1,385 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__device__ void rope_single_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int32_t offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 pos,
|
||||
uint2 dims) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
uint index_1, index_2;
|
||||
if (traditional) {
|
||||
index_1 = 2 * pos.x + pos.y * stride;
|
||||
index_2 = index_1 + 1;
|
||||
} else {
|
||||
index_1 = pos.x + pos.y * stride;
|
||||
index_2 = index_1 + dims.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[index_1]);
|
||||
float x2 = static_cast<float>(in[index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[index_1] = static_cast<T>(rx1);
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
int64_t stride,
|
||||
uint2 dims) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_single_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
int64_t stride,
|
||||
uint2 dims,
|
||||
int64_t freq_stride) {
|
||||
uint2 pos = make_uint2(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, *offset, inv_freq, scale, stride, pos, dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const cuda::std::array<int64_t, 3> strides,
|
||||
const cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
float costheta = cos(theta);
|
||||
float sintheta = sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
size_t in_index_1, in_index_2;
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
in_index_1 += strides[0];
|
||||
in_index_2 += strides[0];
|
||||
out_index_1 += out_strides[0];
|
||||
out_index_2 += out_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
float scale,
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
|
||||
float inv_freq = exp2(-d * base);
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
__global__ void rope_freqs(
|
||||
const T* in,
|
||||
T* out,
|
||||
const int32_t* offset,
|
||||
const float* freqs,
|
||||
float scale,
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
blockIdx.y * blockDim.y + threadIdx.y,
|
||||
blockIdx.z * blockDim.z + threadIdx.z);
|
||||
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool RoPE::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("RoPE::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
cuda::std::array<int64_t, 3> strides;
|
||||
cuda::std::array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
// We apply rope to less that the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(in, out, ctype, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
|
||||
// Either copy or apply in-place
|
||||
else if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
donated = true;
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else {
|
||||
// Copy non-contiguous > 3D inputs into the output and treat
|
||||
// input as donated
|
||||
donated = true;
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
}
|
||||
out_strides[0] = mat_size;
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(donated ? out : in);
|
||||
encoder.set_input_array(offset);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
|
||||
MLX_SWITCH_BOOL(forward_, FORWARD, {
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
inputs[2].data<float>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
scale_,
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int nsegments = in.data_size() / nsort;
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgPartition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, true);
|
||||
}
|
||||
|
||||
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Partition::eval_gpu");
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size(),
|
||||
out.size(),
|
||||
const_param<NDIM>(shape),
|
||||
const_param<NDIM>(a_strides),
|
||||
const_param<NDIM>(b_strides),
|
||||
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
|
||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
|
@ -27,12 +27,14 @@ constexpr bool supports_unary_op() {
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
@ -91,7 +93,7 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||
in_ptr, in.data_size(), shape, strides);
|
||||
in_ptr, in.size(), shape, strides);
|
||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||
}
|
||||
} else {
|
||||
|
@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
if (dtype == bfloat16) {
|
||||
return "__nv_bfloat16";
|
||||
}
|
||||
if (dtype == complex64) {
|
||||
return "cuComplex";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (dtype == DTYPE) { \
|
||||
return #CPP_TYPE; \
|
||||
|
@ -4,12 +4,15 @@
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/thread_safey.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
std::mutex metal_operation_mutex;
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
@ -30,6 +33,7 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
@ -78,6 +82,7 @@ void eval(array& arr) {
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
@ -88,6 +93,7 @@ void finalize(Stream s) {
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
std::lock_guard<std::mutex> lock(metal_operation_mutex);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/thread_safey.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -27,6 +28,7 @@ void Event::wait(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
auto command_buffer = d.get_command_buffer(stream.index);
|
||||
@ -41,6 +43,7 @@ void Event::signal(Stream stream) {
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
});
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
auto command_buffer = d.get_command_buffer(stream.index);
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/thread_safey.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -68,6 +69,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
auto idx = stream.index;
|
||||
|
||||
@ -116,6 +118,7 @@ void Fence::update(Stream stream, const array& x) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(gpu::metal_operation_mutex);
|
||||
auto& d = metal::device(stream.device);
|
||||
auto idx = stream.index;
|
||||
if (!f.use_fast) {
|
||||
|
7
mlx/backend/metal/thread_safey.h
Normal file
7
mlx/backend/metal/thread_safey.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
extern std::mutex metal_operation_mutex;
|
||||
}
|
@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
|
||||
return metal_fast_synch;
|
||||
}
|
||||
|
||||
inline bool enable_tf32() {
|
||||
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
|
||||
return enable_tf32_;
|
||||
}
|
||||
|
||||
} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
5
python/tests/__main__.py
Normal file
5
python/tests/__main__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from . import mlx_tests
|
||||
|
||||
__unittest = True
|
||||
|
||||
mlx_tests.MLXTestRunner(module=None)
|
107
python/tests/cuda_skip.py
Normal file
107
python/tests/cuda_skip.py
Normal file
@ -0,0 +1,107 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestBF16.test_arg_reduction_ops",
|
||||
"TestBF16.test_reduction_ops",
|
||||
"TestBlas.test_complex_gemm",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
"TestMemory.test_memory_info",
|
||||
"TestLayers.test_group_norm",
|
||||
"TestLayers.test_pooling",
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestLayers.test_sin_pe",
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestOps.test_softmax",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
# Block masked matmul NYI
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
# Gather matmul NYI
|
||||
"TestBlas.test_gather_matmul",
|
||||
"TestBlas.test_gather_matmul_grad",
|
||||
# Scan NYI
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_logcumsumexp",
|
||||
# Hadamard NYI
|
||||
"TestOps.test_hadamard",
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
# Convolutions NYI
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_asymmetric_padding",
|
||||
"TestConv.test_basic_grad_shapes",
|
||||
"TestConv.test_conv2d_unaligned_channels",
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_numpy_conv",
|
||||
"TestConv.test_repeated_conv",
|
||||
"TestConv.test_torch_conv_1D",
|
||||
"TestConv.test_torch_conv_1D_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
"TestConv.test_torch_conv_2D_grad",
|
||||
"TestConv.test_torch_conv_3D",
|
||||
"TestConv.test_torch_conv_3D_grad",
|
||||
"TestConv.test_torch_conv_depthwise",
|
||||
"TestConv.test_torch_conv_general",
|
||||
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
|
||||
"TestConvTranspose.test_torch_conv_transpose_1D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
"TestExportImport.test_export_conv",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestVmap.test_vmap_conv",
|
||||
# FFTs NYI
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
"TestFFT.test_fft_contiguity",
|
||||
"TestFFT.test_fft_exhaustive",
|
||||
"TestFFT.test_fft_grads",
|
||||
"TestFFT.test_fft_into_ifft",
|
||||
"TestFFT.test_fft_large_numbers",
|
||||
"TestFFT.test_fft_shared_mem",
|
||||
"TestFFT.test_fftn",
|
||||
# Lapack ops NYI
|
||||
"TestLinalg.test_cholesky",
|
||||
"TestLinalg.test_cholesky_inv",
|
||||
"TestLinalg.test_eig",
|
||||
"TestLinalg.test_eigh",
|
||||
"TestLinalg.test_inverse",
|
||||
"TestVmap.test_vmap_inverse",
|
||||
"TestLinalg.test_lu",
|
||||
"TestLinalg.test_lu_factor",
|
||||
"TestLinalg.test_pseudo_inverse",
|
||||
"TestLinalg.test_qr_factorization",
|
||||
"TestInit.test_orthogonal",
|
||||
"TestLinalg.test_svd_decomposition",
|
||||
"TestVmap.test_vmap_svd",
|
||||
"TestLinalg.test_tri_inverse",
|
||||
# Quantization NYI
|
||||
"TestQuantized.test_gather_matmul_grad",
|
||||
"TestQuantized.test_gather_qmm",
|
||||
"TestQuantized.test_gather_qmm_sorted",
|
||||
"TestQuantized.test_non_multiples",
|
||||
"TestQuantized.test_qmm",
|
||||
"TestQuantized.test_qmm_jvp",
|
||||
"TestQuantized.test_qmm_shapes",
|
||||
"TestQuantized.test_qmm_vjp",
|
||||
"TestQuantized.test_qmv",
|
||||
"TestQuantized.test_quantize_dequantize",
|
||||
"TestQuantized.test_qvm",
|
||||
"TestQuantized.test_qvm_splitk",
|
||||
"TestQuantized.test_small_matrix",
|
||||
"TestQuantized.test_throw",
|
||||
"TestQuantized.test_vjp_scales_biases",
|
||||
}
|
@ -9,6 +9,42 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLXTestRunner(unittest.TestProgram):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def createTests(self, *args, **kwargs):
|
||||
super().createTests(*args, **kwargs)
|
||||
|
||||
# Asume CUDA backend in this case
|
||||
device = os.getenv("DEVICE", None)
|
||||
if device is not None:
|
||||
device = getattr(mx, device)
|
||||
else:
|
||||
device = mx.default_device()
|
||||
|
||||
if not (device == mx.gpu and not mx.metal.is_available()):
|
||||
return
|
||||
|
||||
from cuda_skip import cuda_skip
|
||||
|
||||
filtered_suite = unittest.TestSuite()
|
||||
|
||||
def filter_and_add(t):
|
||||
if isinstance(t, unittest.TestSuite):
|
||||
for sub_t in t:
|
||||
filter_and_add(sub_t)
|
||||
else:
|
||||
t_id = ".".join(t.id().split(".")[-2:])
|
||||
if t_id in cuda_skip:
|
||||
print(f"Skipping {t_id}")
|
||||
else:
|
||||
filtered_suite.addTest(t)
|
||||
|
||||
filter_and_add(self.test)
|
||||
self.test = filtered_suite
|
||||
|
||||
|
||||
class MLXTestCase(unittest.TestCase):
|
||||
@property
|
||||
def is_apple_silicon(self):
|
||||
|
@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||
check_slices(
|
||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
|
||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
|
||||
)
|
||||
|
||||
# Multiple slices
|
||||
@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
|
||||
# Restore device
|
||||
mx.set_default_device(device)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
|
||||
def test_device_context(self):
|
||||
default = mx.default_device()
|
||||
diff = mx.cpu if default == mx.gpu else mx.gpu
|
||||
@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
|
||||
def test_multistream_deadlock(self):
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
def test_sdpa_prommote_mask(self):
|
||||
def test_sdpa_promote_mask(self):
|
||||
mask = mx.array(2.0, mx.bfloat16)
|
||||
D = 64
|
||||
Nq = 4
|
||||
@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
logits, targets, reduction="mean"
|
||||
)
|
||||
expected_mean = mx.mean(expected_none)
|
||||
self.assertEqual(losses_mean, expected_mean)
|
||||
self.assertTrue(mx.allclose(losses_mean, expected_mean))
|
||||
|
||||
# Test with reduction 'sum'
|
||||
losses_sum = nn.losses.binary_cross_entropy(
|
||||
logits, targets, reduction="sum"
|
||||
)
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
# With weights, no label smoothing
|
||||
weights = mx.array([1.0, 2.0, 1.0, 2.0])
|
||||
@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
@ -3127,4 +3091,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
@ -9,7 +9,9 @@ FetchContent_MakeAvailable(doctest)
|
||||
|
||||
add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_TEST_SOURCES gpu_tests.cpp metal_thread_safety_tests.cpp)
|
||||
elseif(MLX_BUILD_CUDA)
|
||||
set(METAL_TEST_SOURCES gpu_tests.cpp)
|
||||
endif()
|
||||
|
||||
|
@ -589,6 +589,7 @@ TEST_CASE("test array shared buffer") {
|
||||
array b = array(buf_b, shape, float32, deleter);
|
||||
|
||||
eval(a + b);
|
||||
synchronize(); // ensure all operations complete before test ends
|
||||
}
|
||||
|
||||
TEST_CASE("test make empty array") {
|
||||
|
250
tests/metal_thread_safety_tests.cpp
Normal file
250
tests/metal_thread_safety_tests.cpp
Normal file
@ -0,0 +1,250 @@
|
||||
#include "doctest/doctest.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <iostream>
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
// Helper function to run operations across multiple threads with pre-created streams
|
||||
void run_in_threads(int num_threads, const std::function<void(int, Stream)>& func,
|
||||
const std::vector<Stream>& streams) {
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_threads);
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
threads.emplace_back(func, i, streams[i % streams.size()]);
|
||||
}
|
||||
for (auto& t : threads) {
|
||||
if (t.joinable()) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for tasks not requiring streams (e.g., using default stream)
|
||||
void run_in_threads_default(int num_threads, const std::function<void(int)>& func) {
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_threads);
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
threads.emplace_back(func, i);
|
||||
}
|
||||
for (auto& t : threads) {
|
||||
if (t.joinable()) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Thread-safe result collection
|
||||
struct TestResults {
|
||||
std::mutex mutex;
|
||||
std::vector<bool> shape_checks;
|
||||
std::vector<bool> availability_checks;
|
||||
std::vector<bool> value_checks;
|
||||
std::vector<float> expected_values;
|
||||
std::vector<float> actual_values;
|
||||
|
||||
void record_result(bool shape_ok, bool available_ok, bool value_ok,
|
||||
float expected, float actual) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
shape_checks.push_back(shape_ok);
|
||||
availability_checks.push_back(available_ok);
|
||||
value_checks.push_back(value_ok);
|
||||
expected_values.push_back(expected);
|
||||
actual_values.push_back(actual);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_CASE("test metal concurrent eval operations") {
|
||||
Device D_GPU = Device::gpu;
|
||||
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
|
||||
const int ops_per_thread = 10;
|
||||
const int array_size = 32;
|
||||
std::atomic<int> completed_ops{0};
|
||||
TestResults results;
|
||||
|
||||
// Pre-create streams to avoid concurrent stream creation
|
||||
std::vector<Stream> streams;
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
streams.push_back(new_stream(D_GPU));
|
||||
}
|
||||
synchronize(); // Ensure stream creation is complete
|
||||
|
||||
auto task = [&](int thread_id, Stream s) {
|
||||
try {
|
||||
for (int i = 0; i < ops_per_thread; ++i) {
|
||||
float val1 = static_cast<float>(thread_id * ops_per_thread + i + 1);
|
||||
float val2 = val1 * 2.0f;
|
||||
|
||||
auto x = full({array_size, array_size}, val1, s);
|
||||
auto y = full({array_size, array_size}, val2, s);
|
||||
auto z = add(x, y);
|
||||
eval(z);
|
||||
|
||||
bool shape_ok = (z.shape() == Shape{array_size, array_size});
|
||||
bool available_ok = z.is_available();
|
||||
|
||||
// Get a value from the array
|
||||
int mid = array_size/2;
|
||||
auto sample = slice(z, {mid, mid}, {mid+1, mid+1});
|
||||
float actual = sample.item<float>();
|
||||
float expected = val1 + val2;
|
||||
|
||||
bool values_match = (std::abs(actual - expected) < 1e-5);
|
||||
|
||||
results.record_result(shape_ok, available_ok, values_match, expected, actual);
|
||||
|
||||
if (shape_ok && available_ok && values_match) {
|
||||
completed_ops++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Thread " << thread_id << " exception: " << e.what() << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
// Run the threads with pre-created streams
|
||||
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
|
||||
|
||||
// Check all results outside of threads
|
||||
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
|
||||
CAPTURE(i); // Help identify which operation failed
|
||||
CHECK(results.shape_checks[i]);
|
||||
CHECK(results.availability_checks[i]);
|
||||
CHECK(results.value_checks[i]);
|
||||
if (!results.value_checks[i]) {
|
||||
CAPTURE(results.expected_values[i]);
|
||||
CAPTURE(results.actual_values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all operations completed successfully
|
||||
CHECK_EQ(completed_ops.load(), num_threads * ops_per_thread);
|
||||
}
|
||||
|
||||
TEST_CASE("test metal high contention on default stream eval") {
|
||||
Device D_GPU = Device::gpu;
|
||||
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 8;
|
||||
const int ops_per_thread = 5;
|
||||
const int array_size = 16;
|
||||
Stream default_gpu_stream = default_stream(D_GPU);
|
||||
std::atomic<int> successful_ops{0};
|
||||
std::vector<std::string> thread_errors;
|
||||
std::mutex errors_mutex;
|
||||
TestResults results;
|
||||
|
||||
auto task = [&](int thread_id) {
|
||||
try {
|
||||
for (int i = 0; i < ops_per_thread; ++i) {
|
||||
float val = static_cast<float>(thread_id * 100 + i + 1);
|
||||
auto x = full({array_size, array_size}, val, default_gpu_stream);
|
||||
auto y = full({array_size, array_size}, val * 0.5f, default_gpu_stream);
|
||||
auto z = multiply(x, y);
|
||||
eval(z);
|
||||
|
||||
// Sample a value
|
||||
auto sample = slice(z, {0, 0}, {1, 1});
|
||||
float actual = sample.item<float>();
|
||||
float expected = val * val * 0.5f;
|
||||
|
||||
bool shape_ok = (z.shape() == Shape{array_size, array_size});
|
||||
bool available_ok = z.is_available();
|
||||
bool values_match = (std::abs(actual - expected) < 1e-5);
|
||||
|
||||
results.record_result(shape_ok, available_ok, values_match, expected, actual);
|
||||
|
||||
if (shape_ok && available_ok && values_match) {
|
||||
successful_ops++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::lock_guard<std::mutex> lock(errors_mutex);
|
||||
thread_errors.push_back(std::string("Thread ") +
|
||||
std::to_string(thread_id) +
|
||||
" exception: " + e.what());
|
||||
}
|
||||
};
|
||||
|
||||
// Use the default helper for this test since it uses the default stream
|
||||
CHECK_NOTHROW(run_in_threads_default(num_threads, task));
|
||||
|
||||
// Check for thread errors
|
||||
CHECK(thread_errors.empty());
|
||||
if (!thread_errors.empty()) {
|
||||
for (const auto& err : thread_errors) {
|
||||
CAPTURE(err);
|
||||
}
|
||||
}
|
||||
|
||||
// Check all results
|
||||
for (size_t i = 0; i < results.shape_checks.size(); ++i) {
|
||||
CAPTURE(i);
|
||||
CHECK(results.shape_checks[i]);
|
||||
CHECK(results.availability_checks[i]);
|
||||
CHECK(results.value_checks[i]);
|
||||
if (!results.value_checks[i]) {
|
||||
CAPTURE(results.expected_values[i]);
|
||||
CAPTURE(results.actual_values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify operation count
|
||||
CHECK_EQ(successful_ops.load(), num_threads * ops_per_thread);
|
||||
}
|
||||
|
||||
TEST_CASE("test metal concurrent graph eval from different threads") {
|
||||
Device D_GPU = Device::gpu;
|
||||
const int num_threads = std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() : 4; // Keep modest for clarity
|
||||
const int array_size = 64;
|
||||
TestResults all_results;
|
||||
|
||||
// Pre-create streams
|
||||
std::vector<Stream> streams;
|
||||
for (int i = 0; i < num_threads; ++i) {
|
||||
streams.push_back(new_stream(D_GPU));
|
||||
}
|
||||
synchronize();
|
||||
|
||||
auto task = [&](int thread_id, Stream s) {
|
||||
try {
|
||||
float val1_base = static_cast<float>(thread_id + 1) * 10.0f;
|
||||
auto x = full({array_size, array_size}, val1_base, s);
|
||||
auto y = full({array_size, array_size}, val1_base + 1.0f, s);
|
||||
auto z = add(x, y);
|
||||
auto w = multiply(z, x);
|
||||
eval(w);
|
||||
|
||||
float expected_val = (val1_base + (val1_base + 1.0f)) * val1_base;
|
||||
auto sample = slice(w, {0,0}, {1,1});
|
||||
float actual_val = sample.item<float>();
|
||||
|
||||
bool shape_ok = (w.shape() == Shape{array_size, array_size});
|
||||
bool available_ok = w.is_available();
|
||||
bool value_ok = (std::abs(actual_val - expected_val) < 1e-4);
|
||||
|
||||
all_results.record_result(shape_ok, available_ok, value_ok, expected_val, actual_val);
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Thread " << thread_id << " exception in concurrent graph eval: " << e.what() << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
CHECK_NOTHROW(run_in_threads(num_threads, task, streams));
|
||||
|
||||
CHECK_EQ(all_results.shape_checks.size(), num_threads); // One result per thread
|
||||
for (size_t i = 0; i < num_threads; ++i) {
|
||||
CAPTURE(i);
|
||||
CHECK(all_results.shape_checks[i]);
|
||||
CHECK(all_results.availability_checks[i]);
|
||||
CHECK(all_results.value_checks[i]);
|
||||
if (!all_results.value_checks[i]) {
|
||||
CAPTURE(all_results.expected_values[i]);
|
||||
CAPTURE(all_results.actual_values[i]);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user