Compare commits

..

10 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a9c720e8cd Improve the ring backend initialization 2025-07-11 15:31:28 -07:00
Cheng
8347575ba1 [CUDA] Implement Scan kernel (#2347)
* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
2025-07-10 16:54:12 -07:00
Angelos Katharopoulos
b6eec20260 Fix edge check in qmm_n QuantizedLoader (#2355) 2025-07-10 16:28:50 -07:00
Angelos Katharopoulos
0eb035b4b1 Fix type promotion in Adam with bias correction (#2350) 2025-07-10 11:14:42 -07:00
Cheng
afb9817599 [CUDA] Put version in ptx cache dir path (#2352) 2025-07-10 07:24:21 -07:00
Cheng
8fb3e7a26c [CUDA] Set current device before cudaGraphLaunch (#2351) 2025-07-10 07:24:02 -07:00
jhavukainen
8c7bc30ce4 Align mlx::core::min op nan propagation with NumPy (#2346) 2025-07-10 06:20:43 -07:00
Cheng
85873cb162 [CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)
* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
2025-07-09 18:48:43 -07:00
Awni Hannun
e14ee12491 add zero for argsort vjp (#2345) 2025-07-09 14:37:14 -07:00
jhavukainen
8b9a3f3cea Align mlx::core::max op nan propagation with NumPy (#2339)
* Make max op NaN propagation rules align with numpy

* Adding benchmarks and testing for max op nanpropagation

* Pre-commit formatting

* Fix max complex64 nan propagation and add test

* Improve the cpp unittest

* Only check nans on non-integral types in simd_reduce_impl.

* Cleanup using namespace alias

* Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16.

* Make the max nanpropagation test more meaningful for integer types

* Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR.
2025-07-09 11:26:27 -07:00
33 changed files with 1314 additions and 265 deletions

View File

@@ -192,6 +192,22 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = mx::scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {

View File

@@ -51,6 +51,20 @@ def time_maximum():
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()
time_negative()

View File

@@ -325,7 +325,15 @@ struct MaxReduce {
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
return simd::max(x);
};
};
@@ -342,7 +350,15 @@ struct MinReduce {
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
return simd::min(x);
};
};
@@ -527,10 +543,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
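
Note on this hunk: the single operator() is split with std::enable_if_t so only non-integral types pay for the NaN check (integers cannot hold NaN), and the int8/int16 cases now dispatch on signed types. A minimal host-side C++ sketch of the same dispatch, with plain loops standing in for MLX's simd types and nan_max as a hypothetical name:

// Sketch: NaN-propagating max via the same enable_if split as above.
#include <cmath>
#include <iostream>
#include <type_traits>
#include <vector>

template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> nan_max(const std::vector<T>& v) {
  T acc = v[0];
  for (T x : v) acc = x > acc ? x : acc;  // integers: no NaN to propagate
  return acc;
}

template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> nan_max(const std::vector<T>& v) {
  T acc = v[0];
  for (T x : v) {
    if (std::isnan(x)) return static_cast<T>(NAN);  // propagate, like NumPy
    acc = x > acc ? x : acc;
  }
  return acc;
}

int main() {
  std::vector<float> a{1.0f, NAN, 3.0f};
  std::vector<int> b{1, 7, 3};
  std::cout << nan_max(a) << " " << nan_max(b) << "\n";  // nan 7
}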

View File

@@ -35,6 +35,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning, it is safe to set it to
# true but the warning wouldn't be suppressed.

View File

@@ -20,15 +20,10 @@ namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
int remaining = size - index * N_READS;
if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[0], b[0]);
if ((index + 1) * N_READS > size) {
for (int i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
@@ -44,15 +39,10 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
int remaining = size - index * N_READS;
if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[0], b[offset]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[i]);
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
@@ -70,15 +60,10 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
int remaining = size - index * N_READS;
if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[offset], b[0]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[0]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
@@ -96,15 +81,10 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
int remaining = size - index * N_READS;
if (remaining <= 0) {
return;
}
if (remaining < N_READS) {
for (int i = 0; i < remaining; ++i) {
IdxT offset = index * N_READS + i;
out[offset] = Op{}(a[offset], b[offset]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
@@ -267,7 +247,7 @@ void binary_op_gpu_inplace(
}
});
} else {
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
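
For context, the pattern these kernels now share: each thread owns N_READS consecutive elements, full chunks take the vectorized load/store path, and only the ragged tail falls back to a scalar loop; the dispatch threshold moves from INT32_MAX to UINT32_MAX to match the uint32_t index type. A serial C++ sketch of the same indexing (binary_add is a hypothetical name; the array holds the place of AlignedVector/load_vector/store_vector):

// Serial sketch of the chunk/tail split used by the binary kernels above.
#include <cstdio>
#include <vector>

constexpr int N_READS = 4;

void binary_add(const float* a, const float* b, float* out, int size) {
  int num_threads = (size + N_READS - 1) / N_READS;
  for (int index = 0; index < num_threads; ++index) {  // stand-in for thread_rank()
    if ((index + 1) * N_READS > size) {
      for (int i = index * N_READS; i < size; ++i) {   // ragged tail: scalar
        out[i] = a[i] + b[i];
      }
    } else {
      float vec[N_READS];  // stand-in for AlignedVector + vector load/store
      for (int i = 0; i < N_READS; ++i) {
        vec[i] = a[index * N_READS + i] + b[index * N_READS + i];
      }
      for (int i = 0; i < N_READS; ++i) {
        out[index * N_READS + i] = vec[i];
      }
    }
  }
}

int main() {
  std::vector<float> a{1, 2, 3, 4, 5, 6, 7}, b(7, 10.0f), out(7);
  binary_add(a.data(), b.data(), out.data(), 7);
  for (float x : out) std::printf("%g ", x);  // 11 12 13 14 15 16 17
}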

View File

@@ -17,52 +17,119 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[0]);
out_a[0] = out[0];
out_b[0] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[0]);
out_a[index] = out[0];
out_b[index] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
@@ -82,7 +149,7 @@ __global__ void binary_g_nd(
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
@@ -103,7 +170,7 @@ __global__ void binary_g(
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
constexpr bool supports_binary_two_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
@@ -114,7 +181,7 @@ constexpr bool supports_binary_op() {
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
void binary_two_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
@@ -141,7 +208,7 @@ void binary_op_gpu_inplace(
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
@@ -161,8 +228,12 @@ void binary_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
@@ -179,7 +250,7 @@ void binary_op_gpu_inplace(
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
@@ -198,22 +269,25 @@ void binary_op_gpu_inplace(
}
});
} else {
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
large());
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
@@ -237,7 +311,7 @@ void binary_op_gpu_inplace(
}
template <typename Op>
void binary_op_gpu(
void binary_two_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
@@ -247,7 +321,7 @@ void binary_op_gpu(
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
@@ -255,7 +329,7 @@ void DivMod::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core

View File

@@ -10,19 +10,43 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT>
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = CastOp<In, Out>{}(in[0]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in[0]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename In, typename Out, typename IdxT>
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = CastOp<In, Out>{}(in[index]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
@@ -41,12 +65,19 @@ void copy_contiguous(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large());
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,

View File

@@ -57,6 +57,14 @@ void Device::make_current() {
}
}
CommandEncoder& Device::get_command_encoder(Stream s) {
auto it = encoders_.find(s.index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR(
@@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
}
CommandEncoder& Device::get_command_encoder(Stream s) {
auto it = encoders_.find(s.index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CommandEncoder(Device& d) : stream_(d) {
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
@@ -287,6 +287,7 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// TODO smarter cache policy

View File

@@ -93,6 +93,7 @@ class CommandEncoder {
void insert_graph_dependencies(GraphNode node);
void insert_graph_dependencies(std::vector<GraphNode> nodes);
Device& device_;
CudaStream stream_;
cudaGraph_t graph_;
Worker worker_;

View File

@@ -1,10 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
namespace mlx::core::cu {
@@ -114,36 +111,38 @@ struct LessEqual {
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
auto min_real = cuCrealf(min);
auto max_real = cuCrealf(max);
if (!isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
return min;
} else {
return Log{}(Exp{}(min) + Exp{}(max));
}
} else {
return Log1p{}(Exp{}(min - max)) + max;
}
} else {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
};
struct Maximum {
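
The complex LogAddExp branch above orders operands by real part and evaluates log1p(exp(min - max)) + max so the exponential never sees a large positive argument. A host-side sketch of the same scheme with std::complex (illustrative only; it uses log(1 + w) where MLX uses its complex Log1p, and mirrors the diff's real-part ordering):

#include <cmath>
#include <complex>
#include <iostream>
#include <limits>

std::complex<float> logaddexp(std::complex<float> x, std::complex<float> y) {
  if (std::isnan(x.real()) || std::isnan(x.imag()) || std::isnan(y.real()) ||
      std::isnan(y.imag())) {
    float qnan = std::numeric_limits<float>::quiet_NaN();
    return {qnan, qnan};
  }
  auto hi = x.real() > y.real() ? x : y;
  auto lo = x.real() < y.real() ? x : y;
  if (!std::isfinite(lo.real()) && lo.real() == hi.real()) {
    // Both real parts -inf (answer is -inf) or both +inf: no stable rewrite.
    return lo.real() < 0 ? lo : std::log(std::exp(lo) + std::exp(hi));
  }
  return std::log(1.0f + std::exp(lo - hi)) + hi;  // MLX uses complex log1p here
}

int main() {
  float inf = std::numeric_limits<float>::infinity();
  std::cout << logaddexp({1.0f, 1.0f}, {-inf, 0.0f}) << "\n";  // (1,1)
}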

View File

@@ -0,0 +1,138 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// can not be used in JIT.
#pragma once
#include <cuComplex.h>
#include <cuda/std/cstdint>
namespace mlx::core::cu::detail {
using ieee_float_shape_type = union {
float value;
uint32_t word;
};
inline __device__ void get_float_word(uint32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline __device__ void get_float_word(int32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline __device__ void set_float_word(float& d, uint32_t i) {
ieee_float_shape_type sf_u;
sf_u.word = (i);
(d) = sf_u.value;
}
inline __device__ float frexp_expf(float x, int* expt) {
const uint32_t k = 235;
const float kln2 = 162.88958740F;
float exp_x;
uint32_t hx;
exp_x = expf(x - kln2);
get_float_word(hx, exp_x);
*expt = (hx >> 23) - (0x7f + 127) + k;
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
return exp_x;
}
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
float x, y, exp_x, scale1, scale2;
int ex_expt, half_expt;
x = cuCrealf(z);
y = cuCimagf(z);
exp_x = frexp_expf(x, &ex_expt);
expt += ex_expt;
half_expt = expt / 2;
set_float_word(scale1, (0x7f + half_expt) << 23);
half_expt = expt - half_expt;
set_float_word(scale2, (0x7f + half_expt) << 23);
return cuComplex{
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
}
inline __device__ cuComplex cexpf(const cuComplex& z) {
float x, y, exp_x;
uint32_t hx, hy;
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
x = cuCrealf(z);
y = cuCimagf(z);
get_float_word(hy, y);
hy &= 0x7fffffff;
/* cexp(x + I 0) = exp(x) + I 0 */
if (hy == 0) {
return cuComplex{expf(x), y};
}
get_float_word(hx, x);
/* cexp(0 + I y) = cos(y) + I sin(y) */
if ((hx & 0x7fffffff) == 0) {
return cuComplex{cosf(y), sinf(y)};
}
if (hy >= 0x7f800000) {
if ((hx & 0x7fffffff) != 0x7f800000) {
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
return cuComplex{y - y, y - y};
} else if (hx & 0x80000000) {
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
return cuComplex{0.0, 0.0};
} else {
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
return cuComplex{x, y - y};
}
}
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
/*
* x is between 88.7 and 192, so we must scale to avoid
* overflow in expf(x).
*/
return ldexp_cexpf(z, 0);
} else {
/*
* Cases covered here:
* - x < exp_ovfl and exp(x) won't overflow (common case)
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
* - x = +-Inf (generated by exp())
* - x = NaN (spurious inexact exception from y)
*/
exp_x = expf(x);
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
}
}
} // namespace mlx::core::cu::detail
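
The point of frexp_expf/ldexp_cexpf above: for x roughly in [88.7, 192], expf(x) overflows float even though expf(x) * cos(y) may be representable, so the code computes expf(x - k*ln2) (safely small) and multiplies the scale 2^k back in two power-of-two halves. A sketch of that idea using standard frexp/ldexp-style scaling (cexpf_sketch is a hypothetical name; k and kln2 are the fork's constants):

#include <cmath>
#include <complex>
#include <cstdio>

std::complex<float> cexpf_sketch(float x, float y) {
  const int k = 235;                 // same shift as the forked code
  const float kln2 = 162.88958740f;  // k * ln(2)
  float m = std::exp(x - kln2);      // stays in float range for x up to ~250
  int h1 = k / 2, h2 = k - h1;       // split 2^k so each factor is finite
  float s1 = std::ldexp(1.0f, h1);
  float s2 = std::ldexp(1.0f, h2);
  // Left-to-right products keep intermediates small when cos(y) is tiny.
  return {std::cos(y) * m * s1 * s2, std::sin(y) * m * s1 * s2};
}

int main() {
  float x = 95.0f, y = 1.57079f;              // cos(y) is tiny, exp(95) overflows
  float naive = std::exp(x) * std::cos(y);    // inf * small = inf (wrong)
  auto z = cexpf_sketch(x, y);                // real part ~1.1e36, representable
  std::printf("naive: %g  scaled: %g\n", naive, z.real());
  // z.imag() is still inf: exp(95)*sin(y) genuinely exceeds float range.
}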

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/cuda/device/cexpf.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -150,8 +152,7 @@ struct Exp {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
return detail::cexpf(x);
} else {
return exp(x);
}
@@ -228,8 +229,25 @@ struct Log10 {
struct Log1p {
template <typename T>
__device__ T operator()(T x) {
return log1p(x);
__device__ T operator()(T z) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float x = cuCrealf(z);
float y = cuCimagf(z);
float zabs = cuCrealf(Abs{}(z));
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
float z0 = hypotf(x + 1, y);
return {logf(z0), theta};
}
} else {
return log1p(z);
}
}
};
@@ -387,19 +405,19 @@ struct Tanh {
}
};
__device__ cuComplex ArcCos::operator()(cuComplex x) {
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
auto i = cuComplex{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcSin::operator()(cuComplex x) {
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcTan::operator()(cuComplex x) {
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto ix = i * x;
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));

View File

@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
}
};
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
} // namespace mlx::core::cu

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/version.h"
#include "cuda_jit_sources.h"
@@ -53,10 +54,11 @@ const std::string& cuda_home() {
const std::filesystem::path& ptx_cache_dir() {
static std::filesystem::path cache = []() -> std::filesystem::path {
std::filesystem::path cache;
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
cache = c;
} else {
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
cache =
std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
}
if (!std::filesystem::exists(cache)) {
std::error_code error;
@@ -159,6 +161,7 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh",
INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "cexpf.cuh",
INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh",
@@ -175,6 +178,7 @@ constexpr const char* g_headers[] = {
jit_source_atomic_ops,
jit_source_binary_ops,
jit_source_cast_op,
jit_source_cexpf,
jit_source_config,
jit_source_cucomplex_math,
jit_source_fp16_math,
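
This hunk renames the override variable to MLX_PTX_CACHE_DIR and keys the cache path on the library version, so a newer build never loads stale PTX produced by an older one. A sketch of the resulting lookup with std::filesystem (version() here is a stand-in for mlx/version.h, and the error handling is simplified relative to the real code):

#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <string>

std::string version() { return "0.26.0"; }  // hypothetical stand-in value

std::filesystem::path ptx_cache_dir() {
  std::filesystem::path cache;
  if (const char* c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
    cache = c;  // explicit override wins
  } else {
    // Version-scoped default: /tmp/mlx/<version>/ptx
    cache = std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
  }
  std::error_code error;
  std::filesystem::create_directories(cache, error);
  return error ? std::filesystem::path{} : cache;
}

int main() { std::cout << ptx_cache_dir() << "\n"; }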

View File

@@ -82,7 +82,6 @@ NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)

View File

@@ -4,6 +4,7 @@
#include <numeric>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>

mlx/backend/cuda/scan.cu (new file, 467 lines)
View File

@@ -0,0 +1,467 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/scan.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename T>
struct ScanResult {
using type = T;
};
template <>
struct ScanResult<Sum, bool> {
using type = int32_t;
};
template <typename T>
struct ReduceInit<LogAddExp, T> {
static constexpr __host__ __device__ T value() {
return Limits<T>::min();
}
};
template <bool reverse, typename T, typename U, int N_READS>
inline __device__ void
load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
int remaining = size - index * N_READS;
if constexpr (reverse) {
in += remaining - N_READS;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
values[N_READS - i - 1] =
(N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
}
} else {
for (int i = 0; i < N_READS; ++i) {
values[N_READS - i - 1] = cast_to<U>(in[i]);
}
}
} else {
in += index * N_READS;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
values[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
}
} else {
for (int i = 0; i < N_READS; ++i) {
values[i] = cast_to<U>(in[i]);
}
}
}
}
template <bool reverse, int offset, typename T, int N_READS>
inline __device__ void
store_values(int index, T* out, T (&values)[N_READS], int size) {
int start = index * N_READS + offset;
int remaining = size - start;
if constexpr (reverse) {
out += remaining - N_READS;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
if (N_READS - i - 1 < remaining) {
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i = 0; i < N_READS; ++i) {
out[i] = values[N_READS - i - 1];
}
}
} else {
out += start;
if (remaining < N_READS) {
for (int i = 0; i < N_READS; ++i) {
if (i < remaining) {
out[i] = values[i];
}
}
} else {
for (int i = 0; i < N_READS; ++i) {
out[i] = values[i];
}
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
__shared__ U warp_sums[WARP_SIZE];
Op op;
U init = ReduceInit<Op, T>::value();
U prefix = init;
// Scan per block.
for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
int32_t index = r * block.size() + block.thread_rank();
U values[N_READS];
load_values<reverse>(index, in, values, axis_size, init);
// Compute an inclusive scan per thread.
for (int i = 1; i < N_READS; ++i) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums.
U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
if (warp.thread_rank() == 0) {
prev_thread_sum = init;
}
// Write each warp's sum to shared memory.
if (warp.thread_rank() == WARP_SIZE - 1) {
warp_sums[warp.meta_group_rank()] =
op(prev_thread_sum, values[N_READS - 1]);
}
block.sync();
// Compute exclusive scan of warp sums.
if (warp.meta_group_rank() == 0) {
U prev_warp_sum =
cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);
if (warp.thread_rank() == 0) {
prev_warp_sum = init;
}
warp_sums[warp.thread_rank()] = prev_warp_sum;
}
block.sync();
// Compute the output.
for (int i = 0; i < N_READS; ++i) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
values[i] = op(values[i], prev_thread_sum);
}
// Write the values.
if (inclusive) {
store_values<reverse, 0>(index, out, values, axis_size);
} else {
store_values<reverse, 1>(index, out, values, axis_size);
if (reverse) {
if (block.thread_rank() == 0 && index == 0) {
out[axis_size - 1] = init;
}
} else {
if (block.thread_rank() == 0 && index == 0) {
out[0] = init;
}
}
}
block.sync();
// Share the prefix.
if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
(warp.thread_rank() == WARP_SIZE - 1)) {
warp_sums[0] = values[N_READS - 1];
}
block.sync();
prefix = warp_sums[0];
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
int BM,
int BN,
bool inclusive,
bool reverse>
__global__ void strided_scan(
const T* in,
U* out,
int32_t axis_size,
int64_t stride,
int64_t stride_blocks) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
constexpr int n_warps = BN / N_READS;
constexpr int n_scans = BN / n_warps;
__shared__ U read_buffer[BM * BN_pad];
Op op;
U init = ReduceInit<Op, T>::value();
U values[n_scans];
U prefix[n_scans];
for (int i = 0; i < n_scans; ++i) {
prefix[i] = init;
}
// Compute offsets.
int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
uint read_offset_y = (block.thread_rank() * N_READS) / BN;
uint read_offset_x = (block.thread_rank() * N_READS) % BN;
uint scan_offset_y = warp.thread_rank();
uint scan_offset_x = warp.meta_group_rank() * n_scans;
uint stride_limit = stride - global_index_x;
in += offset + global_index_x + read_offset_x;
out += offset + global_index_x + read_offset_x;
U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;
for (uint j = 0; j < axis_size; j += BM) {
// Calculate the indices for the current thread.
uint index_y = j + read_offset_y;
uint check_index_y = index_y;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM.
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; ++i) {
read_into[i] = in[index_y * stride + i];
}
} else {
for (int i = 0; i < N_READS; ++i) {
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
read_into[i] = in[index_y * stride + i];
} else {
read_into[i] = init;
}
}
}
block.sync();
// Read strided into registers.
for (int i = 0; i < n_scans; ++i) {
values[i] = read_from[i];
}
// Perform the scan.
for (int i = 0; i < n_scans; ++i) {
values[i] = cg::inclusive_scan(warp, values[i], op);
values[i] = op(values[i], prefix[i]);
prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
}
// Write to SM.
for (int i = 0; i < n_scans; ++i) {
read_from[i] = values[i];
}
block.sync();
// Write to device memory.
if (!inclusive) {
if (check_index_y == 0) {
if ((read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; ++i) {
out[index_y * stride + i] = init;
}
} else {
for (int i = 0; i < N_READS; ++i) {
if ((read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; ++i) {
out[index_y * stride + i] = read_into[i];
}
} else {
for (int i = 0; i < N_READS; ++i) {
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = read_into[i];
}
}
}
}
}
} // namespace cu
template <typename F>
void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {
if (scan_op == Scan::ReduceType::Max) {
f(type_identity<cu::Max>{});
} else if (scan_op == Scan::ReduceType::Min) {
f(type_identity<cu::Min>{});
} else if (scan_op == Scan::ReduceType::Sum) {
f(type_identity<cu::Sum>{});
} else if (scan_op == Scan::ReduceType::Prod) {
f(type_identity<cu::Prod>{});
} else if (scan_op == Scan::ReduceType::LogAddExp) {
f(type_identity<cu::LogAddExp>{});
} else {
throw std::invalid_argument("Unknown reduce type.");
}
}
template <typename Op>
const char* op_to_string() {
if (cuda::std::is_same_v<Op, cu::Max>) {
return "Max";
} else if (cuda::std::is_same_v<Op, cu::Min>) {
return "Min";
} else if (cuda::std::is_same_v<Op, cu::Sum>) {
return "Sum";
} else if (cuda::std::is_same_v<Op, cu::Prod>) {
return "Prod";
} else if (cuda::std::is_same_v<Op, cu::LogAddExp>) {
return "LogAddExp";
} else {
throw std::invalid_argument("Unknown op.");
}
}
template <typename Op, typename T>
constexpr bool supports_scan_op() {
if constexpr (cuda::std::is_same_v<Op, LogAddExp>) {
return is_inexact_v<T>;
} else {
return true;
}
}
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Scan::eval_gpu");
assert(inputs.size() == 1);
auto in = inputs[0];
auto& s = stream();
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in);
}
constexpr int N_READS = 4;
int32_t axis_size = in.shape(axis_);
bool contiguous = in.strides()[axis_] == 1;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
using Op = MLX_GET_TYPE(scan_op_tag);
if constexpr (supports_scan_op<Op, T>) {
using U = typename cu::ScanResult<Op, T>::type;
dispatch_bool(inclusive_, [&](auto inclusive) {
dispatch_bool(reverse_, [&](auto reverse) {
if (contiguous) {
auto kernel = cu::contiguous_scan<
T,
U,
Op,
N_READS,
inclusive.value,
reverse.value>;
int block_dim = cuda::ceil_div(axis_size, N_READS);
block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
encoder.add_kernel_node(
kernel,
in.data_size() / axis_size,
block_dim,
in.data<T>(),
out.data<U>(),
axis_size);
} else {
constexpr int BM = WARP_SIZE;
constexpr int BN = WARP_SIZE;
auto kernel = cu::strided_scan<
T,
U,
Op,
N_READS,
BM,
BN,
inclusive.value,
reverse.value>;
int64_t stride = in.strides()[axis_];
int64_t stride_blocks = cuda::ceil_div(stride, BN);
dim3 num_blocks = get_2d_grid_dims(
in.shape(), in.strides(), axis_size * stride);
if (num_blocks.x * stride_blocks <= UINT32_MAX) {
num_blocks.x *= stride_blocks;
} else {
num_blocks.y *= stride_blocks;
}
int block_dim = (BN / N_READS) * WARP_SIZE;
encoder.add_kernel_node(
kernel,
num_blocks,
block_dim,
in.data<T>(),
out.data<U>(),
axis_size,
stride,
stride_blocks);
}
});
});
} else {
throw std::runtime_error(fmt::format(
"Can not do scan op {} on inputs of {} with result of {}.",
op_to_string<Op>(),
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
} // namespace mlx::core
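
For reference, the semantics contiguous_scan implements: an inclusive scan writes the running total including the current element; an exclusive scan shifts the writes by one (the store_values<reverse, 1> call) and seeds the first output slot with the op's identity; reverse runs the same recurrence from the other end. A serial C++ reference for Sum only (scan is a hypothetical name):

#include <cstdio>
#include <vector>

std::vector<int> scan(const std::vector<int>& in, bool inclusive, bool reverse) {
  std::vector<int> out(in.size());
  int acc = 0;  // identity for Sum; ReduceInit<Op, T> in the kernel
  int n = static_cast<int>(in.size());
  for (int k = 0; k < n; ++k) {
    int i = reverse ? n - 1 - k : k;  // walk from the far end when reversed
    if (inclusive) {
      acc += in[i];
      out[i] = acc;
    } else {
      out[i] = acc;                   // exclusive: total before this element
      acc += in[i];
    }
  }
  return out;
}

int main() {
  std::vector<int> x{1, 2, 3, 4};
  for (int v : scan(x, true, false)) std::printf("%d ", v);  // 1 3 6 10
  std::printf("| ");
  for (int v : scan(x, false, true)) std::printf("%d ", v);  // 9 7 4 0
}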

View File

@@ -15,12 +15,27 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename T, typename IdxT>
template <typename Op, typename T, typename IdxT, int N_READS>
__global__ void
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[index], c[index]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i], c[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
auto c_vec = load_vector<N_READS>(c, index);
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
@@ -149,11 +164,18 @@ void ternary_op_gpu_inplace(
}
});
} else {
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large());
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,

View File

@@ -18,11 +18,24 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(in[index]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(in_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
@@ -112,14 +125,20 @@ void unary_op_gpu_inplace(
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
if (contig) {
auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large);
kernel,
out.data_size(),
out.shape(),
out.strides(),
large,
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
@@ -128,6 +147,7 @@ void unary_op_gpu_inplace(
out.data<OutType>(),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);

View File

@@ -0,0 +1,134 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// can not be used in JIT.
#pragma once
#include <metal_math>
using ieee_float_shape_type = union {
float value;
uint32_t word;
};
inline void get_float_word(thread uint32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline void get_float_word(thread int32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline void set_float_word(thread float& d, uint32_t i) {
ieee_float_shape_type sf_u;
sf_u.word = (i);
(d) = sf_u.value;
}
inline float frexp_expf(float x, thread int* expt) {
const uint32_t k = 235;
const float kln2 = 162.88958740F;
float exp_x;
uint32_t hx;
exp_x = metal::exp(x - kln2);
get_float_word(hx, exp_x);
*expt = (hx >> 23) - (0x7f + 127) + k;
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
return exp_x;
}
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
float x, y, exp_x, scale1, scale2;
int ex_expt, half_expt;
x = z.real;
y = z.imag;
exp_x = frexp_expf(x, &ex_expt);
expt += ex_expt;
half_expt = expt / 2;
set_float_word(scale1, (0x7f + half_expt) << 23);
half_expt = expt - half_expt;
set_float_word(scale2, (0x7f + half_expt) << 23);
return complex64_t{
metal::cos(y) * exp_x * scale1 * scale2,
metal::sin(y) * exp_x * scale1 * scale2};
}
inline complex64_t cexpf(const thread complex64_t& z) {
float x, y, exp_x;
uint32_t hx, hy;
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
x = z.real;
y = z.imag;
get_float_word(hy, y);
hy &= 0x7fffffff;
/* cexp(x + I 0) = exp(x) + I 0 */
if (hy == 0) {
return complex64_t{metal::exp(x), y};
}
get_float_word(hx, x);
/* cexp(0 + I y) = cos(y) + I sin(y) */
if ((hx & 0x7fffffff) == 0) {
return complex64_t{metal::cos(y), metal::sin(y)};
}
if (hy >= 0x7f800000) {
if ((hx & 0x7fffffff) != 0x7f800000) {
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
return complex64_t{y - y, y - y};
} else if (hx & 0x80000000) {
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
return complex64_t{0.0, 0.0};
} else {
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
return complex64_t{x, y - y};
}
}
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
/*
* x is between 88.7 and 192, so we must scale to avoid
* overflow in expf(x).
*/
return ldexp_cexpf(z, 0);
} else {
/*
* Cases covered here:
* - x < exp_ovfl and exp(x) won't overflow (common case)
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
* - x = +-Inf (generated by exp())
* - x = NaN (spurious inexact exception from y)
*/
exp_x = metal::exp(x);
return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
}
}

View File

@@ -643,14 +643,14 @@ struct QuantizedBlockLoader {
return;
}
if (reduction_dim == 1 && bi >= src_tile_dim.y) {
if (reduction_dim == 1 && bi >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bi >= src_tile_dim.x) {
if (reduction_dim == 0 && bi >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}

View File

@@ -164,7 +164,15 @@ struct Min {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce_impl(T val) {
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
return simd_min(val);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
if (simd_any(val != val)) {
return static_cast<T>(NAN);
}
return simd_min(val);
}
@@ -176,17 +184,52 @@ struct Min {
}
// Operator
U operator()(U a, U b) {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
return a < b ? a : b;
}
};
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
if (metal::isnan(a) || metal::isnan(b)) {
return static_cast<T>(NAN);
} else {
return a < b ? a : b;
}
}
template <>
complex64_t operator()(complex64_t a, complex64_t b) {
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
if (!real_is_nan && !imag_is_nan) {
return a < b ? a : b;
} else if (real_is_nan && !imag_is_nan) {
return complex64_t(
static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
} else if (!real_is_nan && imag_is_nan) {
return complex64_t(
a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
} else {
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
}
};
};
template <typename U>
struct Max {
DEFINE_SIMD_REDUCE()
template <typename T>
T simd_reduce_impl(T val) {
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
return simd_max(val);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
if (simd_any(val != val)) {
return static_cast<T>(NAN);
}
return simd_max(val);
}
@@ -198,7 +241,35 @@ struct Max {
}
// Operator
U operator()(U a, U b) {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
return a > b ? a : b;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
if (metal::isnan(a) || metal::isnan(b)) {
return static_cast<T>(NAN);
} else {
return a > b ? a : b;
}
}
template <>
complex64_t operator()(complex64_t a, complex64_t b) {
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
if (!real_is_nan && !imag_is_nan) {
return a > b ? a : b;
} else if (real_is_nan && !imag_is_nan) {
return complex64_t(
static_cast<float>(NAN), a.imag > b.imag ? a.imag : b.imag);
} else if (!real_is_nan && imag_is_nan) {
return complex64_t(
a.real > b.real ? a.real : b.real, static_cast<float>(NAN));
} else {
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
}
}
};
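
The complex64 specializations above propagate NaN per component: a NaN real part poisons the real result while the imaginary parts are still compared, and vice versa. A host-side sketch with std::complex (nan_max is a hypothetical name; the NaN-free branch assumes MLX's lexicographic complex ordering, real part first):

#include <algorithm>
#include <cmath>
#include <complex>
#include <iostream>

std::complex<float> nan_max(std::complex<float> a, std::complex<float> b) {
  bool real_nan = std::isnan(a.real()) || std::isnan(b.real());
  bool imag_nan = std::isnan(a.imag()) || std::isnan(b.imag());
  if (!real_nan && !imag_nan) {
    if (a.real() != b.real()) return a.real() > b.real() ? a : b;
    return a.imag() > b.imag() ? a : b;  // tie on real: compare imag
  }
  // Per-component rule: NaN wins its component, the other still compares.
  float re = real_nan ? NAN : std::max(a.real(), b.real());
  float im = imag_nan ? NAN : std::max(a.imag(), b.imag());
  return {re, im};
}

int main() {
  std::complex<float> a{NAN, 4.0f}, b{3.0f, 3.0f};
  std::cout << nan_max(a, b) << "\n";  // (nan,4)
}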

View File

@@ -5,6 +5,7 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/cexpf.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
@@ -178,8 +179,7 @@ struct Exp {
return metal::precise::exp(x);
};
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
return cexpf(x);
}
};

View File

@@ -22,78 +22,20 @@
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP
#endif
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case float64: { \
using T = double; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
namespace mlx::core::distributed::ring {
constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
constexpr const int INIT_TIMEOUT = 20000;
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
@@ -503,6 +445,7 @@ std::vector<int> make_connections(
return sockets;
}
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
@@ -550,19 +493,27 @@ class RingGroup : public GroupImpl {
size_ = nodes.size();
int connect_to = (rank_ + 1) % size_;
// We define the connection order by having the rank_ == size_ - 1 connect
// first and accept after.
if (rank_ < connect_to) {
log_info(verbose_, "Rank", rank_, "accepting");
sockets_left_ = std::move(accept_connections(nodes[rank_]));
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
} else {
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
log_info(verbose_, "Rank", rank_, "accepting");
sockets_left_ = std::move(accept_connections(nodes[rank_]));
// Initialize the ring by making all the connections
log_info(verbose_, "Rank", rank_, "accepting");
log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
auto sl = std::async(std::launch::async, accept_connections, nodes[rank_]);
auto sr = std::async(
std::launch::async, make_connections, nodes[connect_to], verbose);
std::future_status status_sl, status_sr;
for (int i = 0; i < 10; i++) {
status_sl = sl.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
status_sr = sr.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
if (status_sl == std::future_status::ready &&
status_sr == std::future_status::ready) {
break;
}
}
if (status_sl != std::future_status::ready ||
status_sr != std::future_status::ready) {
throw std::runtime_error("[ring] Ring initialization timed out");
}
sockets_left_ = std::move(sl.get());
sockets_right_ = std::move(sr.get());
// Failure if we couldn't make right or left sockets
if (sockets_right_.empty()) {
@@ -628,18 +579,24 @@ class RingGroup : public GroupImpl {
}
void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>());
});
}
void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>());
});
}
void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
dispatch_all_types(output.dtype(), [&](auto type_tag) {
using T = MLX_GET_TYPE(type_tag);
all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>());
});
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
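
The rewritten initialization launches accept and connect concurrently and polls both futures against a deadline, so a peer that never answers fails the whole init after INIT_TIMEOUT instead of blocking forever. A self-contained sketch of that pattern (accept_side/connect_side are hypothetical stand-ins for accept_connections/make_connections, simulated with sleeps):

#include <chrono>
#include <future>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>

constexpr int INIT_TIMEOUT = 20000;  // ms, as in the diff

std::vector<int> accept_side() {   // pretend blocking accept
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  return {3};
}
std::vector<int> connect_side() {  // pretend blocking connect
  std::this_thread::sleep_for(std::chrono::milliseconds(80));
  return {4};
}

int main() {
  auto sl = std::async(std::launch::async, accept_side);
  auto sr = std::async(std::launch::async, connect_side);
  auto status_sl = std::future_status::timeout;
  auto status_sr = std::future_status::timeout;
  for (int i = 0; i < 10; i++) {  // poll both, a tenth of the budget per round
    status_sl = sl.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
    status_sr = sr.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
    if (status_sl == std::future_status::ready &&
        status_sr == std::future_status::ready) {
      break;
    }
  }
  if (status_sl != std::future_status::ready ||
      status_sr != std::future_status::ready) {
    throw std::runtime_error("[ring] Ring initialization timed out");
  }
  auto left = sl.get(), right = sr.get();
  std::cout << "connected: " << left[0] << " " << right[0] << "\n";
}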

View File

@@ -620,10 +620,11 @@ std::vector<array> ArgReduce::vjp(
}
std::vector<array> ArgReduce::jvp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {zeros_like(tangents[0], stream())};
auto shape = output_shapes(primals)[0];
return {zeros(shape, uint32, stream())};
}
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
@@ -647,6 +648,21 @@ bool ArgSort::is_equivalent(const Primitive& other) const {
return axis_ == r_other.axis_;
}
std::vector<array> ArgSort::vjp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
return {zeros_like(primals[0], stream())};
}
std::vector<array> ArgSort::jvp(
const std::vector<array>& primals,
const std::vector<array>&,
const std::vector<int>&) {
return {zeros(primals[0].shape(), uint32, stream())};
}
std::vector<array> AsType::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -378,6 +378,7 @@ class ArgSort : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ArgSort)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;

View File

@@ -526,8 +526,10 @@ class Adam(Optimizer):
state["v"] = v
if bias_correction:
numerator = lr / (1 - b1**step) * m
denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
numerator = c1 * m
denominator = mx.sqrt(v) * c2 + eps
return parameter - numerator / denominator
else:
return parameter - lr * m / (mx.sqrt(v) + eps)
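
The fix above folds the step-dependent factors into scalars cast to the gradient dtype (with mx.rsqrt replacing a division by mx.sqrt) before they multiply the state, so float32 Python scalars no longer promote a float16 update. The same pitfall in C++ terms, with double standing in for float32 and float for float16 (illustrative only):

#include <cmath>
#include <cstdio>

int main() {
  double lr = 1e-2, b1 = 0.9, b2 = 0.999, eps = 1e-8;
  int step = 10;
  float m = 0.3f, v = 0.02f, param = 1.0f;  // optimizer state, low precision

  // Before: the double scalars promote the whole expression to double,
  // just as float32 scalars promoted the float16 update in MLX.
  double promoted = param - lr / (1 - std::pow(b1, step)) * m /
      (std::sqrt(static_cast<double>(v)) / std::sqrt(1 - std::pow(b2, step)) + eps);

  // After: compute the scalar corrections, cast down, then touch the state.
  float c1 = static_cast<float>(lr / (1 - std::pow(b1, step)));
  float c2 = static_cast<float>(1.0 / std::sqrt(1 - std::pow(b2, step)));
  float eps_f = static_cast<float>(eps);
  float updated = param - (c1 * m) / (std::sqrt(v) * c2 + eps_f);

  std::printf("promoted=%.7f low-precision=%.7f\n", promoted,
              static_cast<double>(updated));
}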

View File

@@ -3,6 +3,8 @@ cuda_skip = {
"TestLayers.test_quantized_embedding",
"TestOps.test_dynamic_slicing",
"TestReduce.test_dtypes",
"TestReduce.test_nanpropagation",
"TestReduce.test_nanpropagation_complex64",
# Block masked matmul NYI
"TestBlas.test_block_masked_matmul",
# Gather matmul NYI
@@ -11,11 +13,6 @@ cuda_skip = {
"TestBlas.test_gather_mm_sorted",
# Segmented matmul NYI
"TestBlas.test_segmented_mm",
# Scan NYI
"TestArray.test_api",
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",
# Hadamard NYI
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",

View File

@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@@ -150,4 +151,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx
import mlx_distributed_tests
import mlx_tests
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):

View File

@@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase):
)
)
# Test for correct gradient type propagation
params = tree_map(lambda x: x.astype(mx.float16), params)
grads = tree_map(lambda x: x.astype(mx.float16), grads)
optim = opt.Adam(1e-2, bias_correction=True)
new_params = optim.apply_gradients(grads, params)
self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))
@unittest.skipIf(not has_torch, "requires Torch")
def test_adamw_matches_pytorch(self):
mx.random.seed(0)

View File

@@ -153,6 +153,63 @@ class TestReduce(mlx_tests.MLXTestCase):
x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
check(x, (1, 3, 5, 7, 9))
def test_nanpropagation(self):
dtypes = [
"uint8",
"uint16",
"uint32",
"int8",
"int16",
"int32",
"float16",
"float32",
]
for dtype in dtypes:
with self.subTest(dtype=dtype):
x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype))
indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)
for idx in indices:
x[idx[0], idx[1]] = mx.nan
x_np = np.array(x)
for op in ["max", "min"]:
for axis in [0, 1]:
out = getattr(mx, op)(x, axis=axis)
ref = getattr(np, op)(x_np, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
def test_nanpropagation_complex64(self):
complex_array_1 = mx.array(
[1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64
).reshape(2, 2)
complex_array_2 = mx.array(
[1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64
).reshape(2, 2)
complex_array_3 = mx.array(
[1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64
).reshape(2, 2)
complex_array_4 = mx.array(
[mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64
).reshape(2, 2)
np_arrays = [
np.array(complex_array_1),
np.array(complex_array_2),
np.array(complex_array_3),
np.array(complex_array_4),
]
for mx_arr, np_arr in zip(
[complex_array_1, complex_array_2, complex_array_3, complex_array_4],
np_arrays,
):
for axis in [0, 1]:
for op in ["max", "min"]:
out = getattr(mx, op)(mx_arr, axis=axis)
ref = getattr(np, op)(np_arr, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") {
x = array({true, true, true, false, true, false}, {2, 3});
CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());
CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());
x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item<bool>());
CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item<bool>());
}
// Test logsumexp
@@ -1346,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
CHECK(allclose(exp(x), expected).item<bool>());
// Complex of -inf
constexpr float inf = std::numeric_limits<float>::infinity();
x = array(complex64_t{-inf, -inf});
CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
}
// Test expm1
@@ -1826,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
x = array(-inf);
y = array(inf);
CHECK_EQ(logaddexp(x, y).item<float>(), inf);
x = array(complex64_t{1, 1});
y = array(complex64_t{-inf, -inf});
CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
}
TEST_CASE("test broadcast") {