Compare commits

...

19 Commits

Author SHA1 Message Date
Param Thakkar
445478c98b
Merge eeaf1fa463 into cad5c0241c 2025-06-18 10:49:45 +08:00
Awni Hannun
cad5c0241c
[CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear

* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a
divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f
Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451
[CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9
add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b
RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d
Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Param Thakkar
eeaf1fa463
Update fft_tests.cpp 2025-05-07 15:24:01 +05:30
paramthakkar123
80002ed42f Updates 2025-05-06 09:53:10 +05:30
paramthakkar123
f7f323f6ae Added stft and istft in the header 2025-04-23 19:19:04 +05:30
Param Thakkar
a963a15b8d
Merge branch 'main' into stft 2025-04-23 00:02:52 +05:30
paramthakkar123
0db393acd7 Fixed fft.cpp and added tests 2025-04-18 23:21:10 +05:30
paramthakkar123
c92a2bc679 changed to mx::pad and added stft to a single fft namespace 2025-04-18 23:02:55 +05:30
paramthakkar123
faf76a8133 Added bindings and removed python implementation 2025-04-18 01:45:06 +05:30
paramthakkar123
b4f01a8f7d Added cpp implementation of stft 2025-04-17 22:27:46 +05:30
paramthakkar123
91097e6179 Added docstrings 2025-04-17 08:20:32 +05:30
paramthakkar123
f50cce83a5 Formatted 2025-04-17 00:24:49 +05:30
paramthakkar123
9e8befbe8d Added Short time fourier transform and ISTFT implementations 2025-04-13 08:55:56 +05:30
78 changed files with 1477 additions and 243 deletions

View File

@ -234,6 +234,7 @@ jobs:
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
build_release:
parameters:

View File

@ -10,7 +10,7 @@ namespace mx = mlx::core;
void time_value_and_grad() {
auto x = mx::ones({200, 1000});
mx::eval(x);
auto fn = [](mx::array x) {
auto fn = [](mx::x) {
for (int i = 0; i < 20; ++i) {
x = mx::log(mx::exp(x));
}

View File

@ -107,6 +107,16 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
Transformations of functions which use in-place updates are allowed and work as
expected. For example:

View File

@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core

View File

@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common(
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();

View File

@ -8,6 +8,7 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
@ -32,6 +33,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu

View File

@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
return;
}
}
cudaFree(buf);
}

View File

@ -1,5 +1,4 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
@ -113,7 +112,7 @@ __global__ void arg_reduce_general(
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().z;
auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{1, 1, BLOCK_DIM};
dim3 block_dims{BLOCK_DIM, 1, 1};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,

View File

@ -101,10 +101,12 @@ constexpr bool supports_binary_op() {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> &&
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
return std::is_same_v<Out, bool> && is_inexact_v<In>;
}
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
if (std::is_same_v<Op, LogAddExp>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
@ -123,13 +125,12 @@ constexpr bool supports_binary_op() {
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
array& out,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
@ -144,16 +145,15 @@ void binary_op_gpu_inplace(
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > UINT32_MAX ||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
@ -217,20 +217,6 @@ void binary_op_gpu_inplace(
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
@ -241,8 +227,7 @@ void binary_op_gpu(
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define BINARY_GPU(func) \
@ -252,19 +237,10 @@ void binary_op_gpu(
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
@ -279,6 +255,17 @@ BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();

View File

@ -0,0 +1,248 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[0]);
out_a[0] = out[0];
out_b[0] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[0], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[0]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto out = Op{}(a[index], b[index]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out_a.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core

View File

@ -130,11 +130,13 @@ struct FusedKernelBuilder {
constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
void Compiled::eval_gpu(

View File

@ -6,7 +6,7 @@
namespace mlx::core {
void copy_gpu_inplace(
const array& in_,
const array& in,
array& out,
const Shape& shape,
const Strides& strides_in,
@ -20,7 +20,6 @@ void copy_gpu_inplace(
if (out.size() == 0) {
return;
}
const array& in = in_.data_shared_ptr() ? in_ : out;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);

View File

@ -10,20 +10,13 @@
namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
__VA_ARGS__; \
} else { \
throw std::runtime_error(fmt::format( \
"Can not copy data from dtype {} to {}.", \
dtype_to_string(out.dtype()), \
dtype_to_string(in.dtype()))); \
} \
}); \
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
__VA_ARGS__; \
}); \
})
void copy_contiguous(

View File

@ -43,7 +43,8 @@ void copy_contiguous(
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,

View File

@ -59,29 +59,34 @@ void copy_general(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
data_size,
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@ -65,9 +65,9 @@ void copy_general_dynamic(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -76,7 +76,7 @@ void copy_general_dynamic(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
@ -89,7 +89,7 @@ void copy_general_dynamic(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@ -54,9 +54,9 @@ void copy_general_input(
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -65,7 +65,7 @@ void copy_general_input(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
@ -75,7 +75,7 @@ void copy_general_input(
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.data_size(),
out.size(),
const_param(shape),
const_param(strides_in),
ndim);

View File

@ -6,6 +6,7 @@
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <future>
namespace mlx::core {
@ -107,6 +108,16 @@ void CommandEncoder::commit() {
worker_.commit(stream_.last_cuda_stream());
}
void CommandEncoder::synchronize() {
stream().synchronize();
auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch();
worker_.commit();
f.wait();
}
Device& device(mlx::core::Device device) {
static std::unordered_map<int, Device> devices;
auto it = devices.find(device.index);

View File

@ -123,6 +123,9 @@ class CommandEncoder {
return has_gpu_work_;
}
// Wait until kernels and completion handlers are finished
void synchronize();
private:
Device& device_;
DeviceStream& stream_;

View File

@ -1,6 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
@ -20,7 +22,7 @@ struct FloorDivide {
if constexpr (cuda::std::is_integral_v<T>) {
return x / y;
} else {
return trunc(x / y);
return truncf(x / y);
}
}
};
@ -122,6 +124,26 @@ struct LogAddExp {
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
};
struct Maximum {

View File

@ -45,6 +45,18 @@ struct CastOp<
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
static constexpr bool is_castable = true;
__device__ SrcT operator()(SrcT x) {
return x;
}
};
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {

View File

@ -5,7 +5,7 @@
#pragma once
// The maximum dimensions of shape/strides passed as kernel parameters.
#define MAX_NDIM 8
#define MAX_NDIM 10
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.

View File

@ -1,4 +1,5 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cu {

View File

@ -5,6 +5,8 @@
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <math_constants.h>
namespace mlx::core::cu {
struct Abs {
@ -183,21 +185,38 @@ struct Imag {
struct Log {
template <typename T>
__device__ T operator()(T x) {
return log(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto r = log(cuCrealf(Abs{}(x)));
auto i = atan2f(cuCimagf(x), cuCrealf(x));
return {r, i};
} else {
return log(x);
}
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
return log2(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
} else {
return log2(x);
}
}
};
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
return log10(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
return y;
} else {
return log10(x);
}
}
};

View File

@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
for (int i = ndim - 1; i >= 3; --i) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
const int64_t* a_strides,
const int64_t* b_strides,
int ndim) {
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
for (int i = ndim - 1; i >= 3; --i) {
IdxT a_loc = 0;
IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
const int64_t* b_strides,
const int64_t* c_strides,
int ndim) {
auto [a_loc, b_loc, c_loc] =
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
for (int i = ndim - 1; i >= 3; --i) {
IdxT a_loc = 0;
IdxT b_loc = 0;
IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
}
};
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
} // namespace mlx::core::cu

View File

@ -62,7 +62,7 @@ void finalize(Stream s) {
void synchronize(Stream s) {
nvtx3::scoped_range r("gpu::synchronize");
cu::get_stream(s).synchronize();
cu::get_command_encoder(s).synchronize();
}
} // namespace mlx::core::gpu

View File

@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
uint32_t slice_size = std::accumulate(
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_cuda_type(idx_dtype),
nidx,
ndim,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_gather, std::move(kernel_names));
@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(out.size());
} else {
mod.append_arg<uint32_t>(out.size());
mod.append_arg<int32_t>(out.size());
}
mod.append_ndim_arg(src.shape());
mod.append_ndim_arg(src.strides());
@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
dtype_to_cuda_type(idx_dtype),
nidx,
idx_ndim,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
uint32_t upd_post_idx_size = std::accumulate(
int32_t upd_post_idx_size = std::accumulate(
upd.shape().begin() + idx_ndim,
upd.shape().end(),
1,
std::multiplies<uint32_t>());
std::multiplies<int32_t>());
const char* op = g_scatter_ops[reduce_type_];
std::string module_name = fmt::format(
@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op,
nidx,
ndim,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_scatter, std::move(kernel_names));
@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(upd.size());
} else {
mod.append_arg<uint32_t>(upd.size());
mod.append_arg<int32_t>(upd.size());
}
mod.append_ndim_arg(upd.shape());
mod.append_ndim_arg(upd.strides());
@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (large) {
mod.append_arg<int64_t>(upd_post_idx_size);
} else {
mod.append_arg<uint32_t>(upd_post_idx_size);
mod.append_arg<int32_t>(upd_post_idx_size);
}
mod.append_ndim_arg(out.shape());
mod.append_ndim_arg(out.strides());
@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op,
nidx,
idx_ndim,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
std::string module_name = fmt::format(
"gather_axis_{}_{}",
@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
ndim,
contiguous & 1 ? true : false,
contiguous & 2 ? true : false,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
}
@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
mod.append_arg<uint32_t>(idx_size_pre);
mod.append_arg<uint32_t>(idx_size_axis);
mod.append_arg<uint32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(src.strides(), axis_));
@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
src.ndim() - 1,
src.flags().row_contiguous,
idx.flags().row_contiguous,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
std::string module_name = fmt::format(
@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
ndim,
contiguous & 1 ? true : false,
contiguous & 2 ? true : false,
large ? "int64_t" : "uint32_t"));
large ? "int64_t" : "int32_t"));
}
}
}
@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
mod.append_arg<uint32_t>(idx_size_pre);
mod.append_arg<uint32_t>(idx_size_axis);
mod.append_arg<uint32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(upd.strides(), axis_));
@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
idx.ndim() - 1,
upd.flags().row_contiguous,
idx.flags().row_contiguous,
large ? "int64_t" : "uint32_t");
large ? "int64_t" : "int32_t");
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {

View File

@ -23,4 +23,11 @@ dim3 get_2d_grid_dims(
return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
}
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2);
auto [gx, gy, gz] = grid;
auto [bx, by, bz] = block;
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}
} // namespace mlx::core

View File

@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type traits for detecting complex or real floating point numbers.
template <typename T>
inline constexpr bool is_inexact_v =
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
@ -121,6 +126,7 @@ dim3 get_2d_grid_dims(
const Shape& shape,
const Strides& strides,
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
@ -135,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) {
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
@ -153,4 +161,14 @@ inline std::tuple<dim3, uint> get_launch_args(
return std::make_tuple(num_blocks, block_dim);
}
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
return get_launch_args(
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core

View File

@ -5,6 +5,7 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
@ -45,7 +46,7 @@ class MatMul {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16) {
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
@ -192,11 +193,12 @@ class MatMul {
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_16F;
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return CUBLAS_COMPUTE_32F;
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;

View File

@ -71,10 +71,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(ArgPartition)
NO_GPU(BlockMaskedMM)
NO_GPU(Convolution)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
@ -83,7 +81,6 @@ NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
@ -94,7 +91,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)

View File

@ -4,6 +4,7 @@
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
@ -12,6 +13,8 @@ namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
__constant__ constexpr uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}};
@ -47,27 +50,28 @@ __global__ void rbitsc(
dim3 grid_dims,
bool odd,
uint32_t bytes_per_key) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
auto grid = cg::this_grid();
uint thread_index = grid.thread_rank();
uint index_x = thread_index % grid_dims.x;
uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto kidx = 2 * index_x;
auto key = uint2{keys[kidx], keys[kidx + 1]};
auto half_size = grid_dims.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
out += index_x * bytes_per_key;
bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
@ -89,30 +93,31 @@ __global__ void rbits(
int32_t ndim,
const __grid_constant__ Shape key_shape,
const __grid_constant__ Strides key_strides) {
uint2 index{
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y};
if (index.x >= grid_dims.x || index.y >= grid_dims.y) {
auto grid = cg::this_grid();
uint thread_index = grid.thread_rank();
uint index_x = thread_index % grid_dims.x;
uint index_y = thread_index / grid_dims.x;
if (index_x >= grid_dims.x || index_y >= grid_dims.y) {
return;
}
auto kidx = 2 * index.x;
auto kidx = 2 * index_x;
auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim);
auto k2_elem =
elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim);
auto key = uint2{keys[k1_elem], keys[k2_elem]};
auto half_size = grid_dims.y - odd;
out += size_t(index.x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
out += size_t(index_x) * bytes_per_key;
bool drop_last = odd && (index_y == half_size);
auto bits = threefry2x32_hash(
key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y});
size_t idx = size_t(index.y) << 2;
key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y});
size_t idx = size_t(index_y) << 2;
for (int i = 0; i < 4; ++i) {
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2;
if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[idx + i] = bits.bytes[1][i];
@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
dim3 grid_dims{num_keys, half_size + odd};
dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1);
dim3 num_blocks{
cuda::ceil_div(grid_dims.x, block_dims.x),
cuda::ceil_div(grid_dims.y, block_dims.y)};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
if (keys.flags().row_contiguous) {
cu::rbitsc<<<num_blocks, block_dims, 0, stream>>>(
cu::rbitsc<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
odd,
bytes_per_key);
} else {
cu::rbits<<<num_blocks, block_dims, 0, stream>>>(
cu::rbits<<<grid, block, 0, stream>>>(
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,

385
mlx/backend/cuda/rope.cu Normal file
View File

@ -0,0 +1,385 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
template <typename T, bool traditional, bool forward>
__device__ void rope_single_impl(
const T* in,
T* out,
int32_t offset,
float inv_freq,
float scale,
int64_t stride,
uint2 pos,
uint2 dims) {
float L = scale * static_cast<float>(offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
uint index_1, index_2;
if (traditional) {
index_1 = 2 * pos.x + pos.y * stride;
index_2 = index_1 + 1;
} else {
index_1 = pos.x + pos.y * stride;
index_2 = index_1 + dims.x;
}
// Read and write the output
float x1 = static_cast<float>(in[index_1]);
float x2 = static_cast<float>(in[index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[index_1] = static_cast<T>(rx1);
out[index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_single(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
int64_t stride,
uint2 dims) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_single_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
int64_t stride,
uint2 dims,
int64_t freq_stride) {
uint2 pos = make_uint2(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y);
if (pos.x >= dims.x || pos.y >= dims.y) {
return;
}
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_single_impl<T, traditional, forward>(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}
template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = cos(theta);
float sintheta = sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
}
}
template <typename T, bool traditional, bool forward>
__global__ void rope(
const T* in,
T* out,
const int32_t* offset,
float scale,
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
template <typename T, bool traditional, bool forward>
__global__ void rope_freqs(
const T* in,
T* out,
const int32_t* offset,
const float* freqs,
float scale,
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z);
if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) {
return;
}
float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_impl<T, traditional, forward>(
in,
out,
*offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
dims);
}
} // namespace cu
namespace fast {
bool RoPE::use_fallback(Stream s) {
return s.device == Device::cpu;
}
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("RoPE::eval_gpu");
auto& s = stream();
auto& in = inputs[0];
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
// We apply rope to less that the whole vector so copy to output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
copy_gpu(in, out, ctype, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
// Either copy or apply in-place
else if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
donated = true;
copy_gpu(in, out, CopyType::General, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(donated ? out : in);
encoder.set_input_array(offset);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
MLX_SWITCH_BOOL(forward_, FORWARD, {
if (single && !with_freqs) {
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
mat_size,
dims);
} else if (single) {
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
mat_size,
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
inputs[2].data<float>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
kernel<<<grid, block, 0, stream>>>(
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
scale_,
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
dims);
}
});
});
});
});
}
} // namespace fast
} // namespace mlx::core

View File

@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
axis += in.ndim();
}
int nsort = in.shape(axis);
int nsegments = in.data_size() / nsort;
int last_dim = in.ndim() - 1;
// If we are not sorting the innermost dimension of a contiguous array,
@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
encoder.launch_kernel([&](cudaStream_t stream) {
@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
nsegments,
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data<Type>(),
out.data<Type>(),
in.data_size(),
nsegments,
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
gpu_sort(stream(), inputs[0], out, axis_, false);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgPartition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, true);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Partition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, false);
}
} // namespace mlx::core

View File

@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
auto& c_strides = strides[2];
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size(),
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides),
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),

View File

@ -27,12 +27,14 @@ constexpr bool supports_unary_op() {
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
@ -91,7 +93,7 @@ void unary_op_gpu_inplace(
} else {
auto [shape, strides] = collapse_contiguous_dims(in);
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
in_ptr, in.data_size(), shape, strides);
in_ptr, in.size(), shape, strides);
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
}
} else {

View File

@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
if (dtype == bfloat16) {
return "__nv_bfloat16";
}
if (dtype == complex64) {
return "cuComplex";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (dtype == DTYPE) { \
return #CPP_TYPE; \

View File

@ -80,7 +80,9 @@ void Worker::thread_fn() {
}
worker_tasks_.erase(worker_tasks_.begin(), end);
}
for (auto& task : tasks) {
// Make sure tasks are cleared before the next wait
for (int i = 0; i < tasks.size(); ++i) {
auto task = std::move(tasks[i]);
task();
}
worker_event_.wait(batch + 1);

View File

@ -1,4 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
#include <numeric>
#include <set>
@ -189,6 +191,124 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, true, s);
}
array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window,
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s /* = {} */) {
if (hop_length == -1)
hop_length = n_fft / 4;
if (win_length == -1)
win_length = n_fft;
array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
if (win_length < n_fft) {
int pad_left = (n_fft - win_length) / 2;
int pad_right = n_fft - win_length - pad_left;
win = mlx::core::pad(
win, {{pad_left, pad_right}}, array(0, float32), "constant", s);
}
array padded_x = x;
if (center) {
int pad_width = n_fft / 2;
padded_x = mlx::core::pad(
padded_x, {{pad_width, pad_width}}, array(0, x.dtype()), pad_mode, s);
}
int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
Shape strided_shape = {n_frames, n_fft};
Strides strided_strides = {
hop_length * static_cast<long long>(sizeof(float32)),
static_cast<long long>(sizeof(float32))};
array frames = as_strided(padded_x, strided_shape, strided_strides, 0, s);
array stacked_frames = multiply(frames, win, s);
array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s);
if (normalized) {
array n_fft_array = full({1}, static_cast<float>(n_fft), float32, s);
stft_result = divide(stft_result, sqrt(n_fft_array, s), s);
}
if (onesided) {
stft_result = slice(stft_result, {}, {n_fft / 2 + 1}, s);
}
return stft_result;
}
array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window,
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s /* = {} */) {
int n_fft = (stft_matrix.shape(-1) - 1) * 2;
if (hop_length == -1)
hop_length = n_fft / 4;
if (win_length == -1)
win_length = n_fft;
array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
if (win_length < n_fft) {
int pad_left = (n_fft - win_length) / 2;
int pad_right = n_fft - win_length - pad_left;
win = mlx::core::pad(
win, {{pad_left, pad_right}}, array(0, float32), "constant", s);
}
array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);
frames = multiply(frames, win, s);
int signal_length = (frames.shape(0) - 1) * hop_length + n_fft;
array signal = zeros({signal_length}, float32, s);
array window_sum = zeros({signal_length}, float32, s);
Shape strided_shape = {frames.shape(0), n_fft};
Strides strided_strides = {
hop_length * static_cast<long long>(sizeof(float32)),
static_cast<long long>(sizeof(float32))};
array signal_strided =
as_strided(signal, strided_shape, strided_strides, 0, s);
array window_sum_strided =
as_strided(window_sum, strided_shape, strided_strides, 0, s);
signal_strided = add(signal_strided, frames, s);
window_sum_strided = add(window_sum_strided, win, s);
signal = divide(signal, window_sum, s);
if (center) {
int pad_width = n_fft / 2;
signal = slice(signal, {pad_width}, {signal.shape(0) - pad_width}, s);
}
if (length > 0) {
if (signal.shape(0) > length) {
signal = slice(signal, {0}, {length}, s);
} else if (signal.shape(0) < length) {
int pad_length = length - signal.shape(0);
signal = mlx::core::pad(
signal, {{0, pad_length}}, array(0, signal.dtype()), "constant", s);
}
}
return signal;
}
array fftshift(
const array& a,
const std::vector<int>& axes,
@ -258,5 +378,4 @@ array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
std::iota(axes.begin(), axes.end(), 0);
return ifftshift(a, axes, s);
}
} // namespace mlx::core::fft
} // namespace mlx::core::fft

View File

@ -6,8 +6,11 @@
#include "array.h"
#include "device.h"
#include "mlx/mlx.h"
#include "utils.h"
namespace mx = mlx::core;
namespace mlx::core::fft {
/** Compute the n-dimensional Fourier Transform. */
@ -164,4 +167,48 @@ array ifftshift(
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s = {}) {
return mlx::core::fft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
}
inline array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s = {}) {
return mlx::core::fft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
}
} // namespace mlx::core::fft

View File

@ -149,6 +149,11 @@ inline bool metal_fast_synch() {
return metal_fast_synch;
}
inline bool enable_tf32() {
static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1);
return enable_tf32_;
}
} // namespace env
} // namespace mlx::core

View File

@ -459,6 +459,101 @@ void init_fft(nb::module_& parent_module) {
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");
m.def(
"stft",
[](const mx::array& x,
int n_fft = 2048,
std::optional<int> hop_length = std::nullopt,
std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
mx::StreamOrDevice s = {}) {
return mx::stft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
},
"x"_a,
"n_fft"_a = 2048,
"hop_length"_a = nb::none(),
"win_length"_a = nb::none(),
"window"_a = mx::array(),
"center"_a = true,
"pad_mode"_a = "reflect",
"normalized"_a = false,
"onesided"_a = true,
"stream"_a = nb::none(),
R"pbdoc(
Short-time Fourier Transform (STFT).
Args:
x (array): Input signal.
n_fft (int, optional): Number of FFT points. Default is 2048.
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
win_length (int, optional): Window size. Default is `n_fft`.
window (array, optional): Window function. Default is a rectangular window.
center (bool, optional): Whether to pad the signal to center the frames. Default is True.
pad_mode (str, optional): Padding mode. Default is "reflect".
normalized (bool, optional): Whether to normalize the STFT. Default is False.
onesided (bool, optional): Whether to return a one-sided STFT. Default is True.
Returns:
array: The STFT of the input signal.
)pbdoc");
m.def(
"istft",
[](const mx::array& stft_matrix,
std::optional<int> hop_length = std::nullopt,
std::optional<int> win_length = std::nullopt,
const mx::array& window = mx::array(),
bool center = true,
int length = -1,
bool normalized = false,
mx::StreamOrDevice s = {}) {
return mx::stft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
},
"stft_matrix"_a,
"hop_length"_a = nb::none(),
"win_length"_a = nb::none(),
"window"_a = mx::array(),
"center"_a = true,
"length"_a = -1,
"normalized"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
Inverse Short-time Fourier Transform (ISTFT).
Args:
stft_matrix (array): Input STFT matrix.
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
win_length (int, optional): Window size. Default is `n_fft`.
window (array, optional): Window function. Default is a rectangular window.
center (bool, optional): Whether the signal was padded to center the frames. Default is True.
length (int, optional): Length of the output signal. Default is inferred from the STFT matrix.
normalized (bool, optional): Whether the STFT was normalized. Default is False.
Returns:
array: The reconstructed signal.
m.def(
"fftshift",
[](const mx::array& a,

5
python/tests/__main__.py Normal file
View File

@ -0,0 +1,5 @@
from . import mlx_tests
__unittest = True
mlx_tests.MLXTestRunner(module=None)

106
python/tests/cuda_skip.py Normal file
View File

@ -0,0 +1,106 @@
cuda_skip = {
"TestArray.test_api",
"TestBF16.test_arg_reduction_ops",
"TestBF16.test_reduction_ops",
"TestBlas.test_complex_gemm",
"TestEinsum.test_ellipses",
"TestEinsum.test_opt_einsum_test_cases",
"TestLoad.test_load_f8_e4m3",
"TestLayers.test_group_norm",
"TestLayers.test_pooling",
"TestLayers.test_quantized_embedding",
"TestLayers.test_sin_pe",
"TestLayers.test_upsample",
"TestOps.test_complex_ops",
"TestOps.test_dynamic_slicing",
"TestOps.test_softmax",
"TestReduce.test_axis_permutation_sums",
"TestReduce.test_dtypes",
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
# Block masked matmul NYI
"TestBlas.test_block_masked_matmul",
# Gather matmul NYI
"TestBlas.test_gather_matmul",
"TestBlas.test_gather_matmul_grad",
# Scan NYI
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",
# Hadamard NYI
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
# Convolutions NYI
"TestConv.test_1d_conv_with_2d",
"TestConv.test_asymmetric_padding",
"TestConv.test_basic_grad_shapes",
"TestConv.test_conv2d_unaligned_channels",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad",
"TestConv.test_numpy_conv",
"TestConv.test_repeated_conv",
"TestConv.test_torch_conv_1D",
"TestConv.test_torch_conv_1D_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_2D_grad",
"TestConv.test_torch_conv_3D",
"TestConv.test_torch_conv_3D_grad",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",
"TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_1D",
"TestConvTranspose.test_torch_conv_transpose_1D_grad",
"TestConvTranspose.test_torch_conv_transpose_2D",
"TestConvTranspose.test_torch_conv_transpose_2D_grad",
"TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
"TestConvTranspose.test_torch_conv_transpose_3D",
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
"TestExportImport.test_export_conv",
"TestLayers.test_conv1d",
"TestLayers.test_conv2d",
"TestVmap.test_vmap_conv",
# FFTs NYI
"TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two",
"TestFFT.test_fft_contiguity",
"TestFFT.test_fft_exhaustive",
"TestFFT.test_fft_grads",
"TestFFT.test_fft_into_ifft",
"TestFFT.test_fft_large_numbers",
"TestFFT.test_fft_shared_mem",
"TestFFT.test_fftn",
# Lapack ops NYI
"TestLinalg.test_cholesky",
"TestLinalg.test_cholesky_inv",
"TestLinalg.test_eig",
"TestLinalg.test_eigh",
"TestLinalg.test_inverse",
"TestVmap.test_vmap_inverse",
"TestLinalg.test_lu",
"TestLinalg.test_lu_factor",
"TestLinalg.test_pseudo_inverse",
"TestLinalg.test_qr_factorization",
"TestInit.test_orthogonal",
"TestLinalg.test_svd_decomposition",
"TestVmap.test_vmap_svd",
"TestLinalg.test_tri_inverse",
# Quantization NYI
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
"TestQuantized.test_non_multiples",
"TestQuantized.test_qmm",
"TestQuantized.test_qmm_jvp",
"TestQuantized.test_qmm_shapes",
"TestQuantized.test_qmm_vjp",
"TestQuantized.test_qmv",
"TestQuantized.test_quantize_dequantize",
"TestQuantized.test_qvm",
"TestQuantized.test_qvm_splitk",
"TestQuantized.test_small_matrix",
"TestQuantized.test_throw",
"TestQuantized.test_vjp_scales_biases",
}

View File

@ -9,6 +9,42 @@ import mlx.core as mx
import numpy as np
class MLXTestRunner(unittest.TestProgram):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def createTests(self, *args, **kwargs):
super().createTests(*args, **kwargs)
# Asume CUDA backend in this case
device = os.getenv("DEVICE", None)
if device is not None:
device = getattr(mx, device)
else:
device = mx.default_device()
if not (device == mx.gpu and not mx.metal.is_available()):
return
from cuda_skip import cuda_skip
filtered_suite = unittest.TestSuite()
def filter_and_add(t):
if isinstance(t, unittest.TestSuite):
for sub_t in t:
filter_and_add(sub_t)
else:
t_id = ".".join(t.id().split(".")[-2:])
if t_id in cuda_skip:
print(f"Skipping {t_id}")
else:
filtered_suite.addTest(t)
filter_and_add(self.test)
self.test = filtered_suite
class MLXTestCase(unittest.TestCase):
@property
def is_apple_silicon(self):

View File

@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
)
# Multiple slices
@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
# Restore device
mx.set_default_device(device)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_device_context(self):
default = mx.default_device()
diff = mx.cpu if default == mx.gpu else mx.gpu
@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.get_peak_memory()
self.assertEqual(pre, post)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_multistream_deadlock(self):
s1 = mx.default_stream(mx.gpu)
s2 = mx.new_stream(mx.gpu)
@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64
Nq = 4
@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase):
logits, targets, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.binary_cross_entropy(
logits, targets, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
# With weights, no label smoothing
weights = mx.array([1.0, 2.0, 1.0, 2.0])
@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqualArray(result, mx.array(expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_issubdtype(self):
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
@ -3127,4 +3091,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main(failfast=True)
mlx_tests.MLXTestRunner(failfast=True)

View File

@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
if __name__ == "__main__":
unittest.main()
mlx_tests.MLXTestRunner()

View File

@ -307,6 +307,80 @@ TEST_CASE("test fft grads") {
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}
TEST_CASE("test stft and istft") {
int n_fft = 4;
int hop_length = 2;
int win_length = 4;
array signal = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, float32);
array window = array({0.5, 1.0, 1.0, 0.5}, float32);
SUBCASE("stft basic functionality") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
CHECK_EQ(stft_result.shape(0), 4);
CHECK_EQ(stft_result.shape(1), 3);
CHECK_EQ(stft_result.dtype(), complex64);
}
SUBCASE("istft reconstruction") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window);
CHECK_EQ(reconstructed_signal.shape(0), signal.shape(0));
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
}
SUBCASE("stft with default parameters") {
auto stft_result = fft::stft(signal);
CHECK_EQ(stft_result.shape(0), 5);
CHECK_EQ(stft_result.shape(1), 3);
CHECK_EQ(stft_result.dtype(), complex64);
}
SUBCASE("istft with length parameter") {
auto stft_result = fft::stft(signal, n_fft, hop_length, win_length, window);
int length = 6;
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window, true, length);
CHECK_EQ(reconstructed_signal.shape(0), length);
CHECK(
allclose(slice(signal, {0}, {length}), reconstructed_signal, 1e-5, 1e-5)
.item<bool>());
}
SUBCASE("stft and istft with normalization") {
auto stft_result = fft::stft(
signal, n_fft, hop_length, win_length, window, true, "reflect", true);
auto reconstructed_signal =
fft::istft(stft_result, hop_length, win_length, window, true, -1, true);
CHECK(allclose(signal, reconstructed_signal, 1e-5, 1e-5).item<bool>());
}
SUBCASE("stft with onesided=False") {
auto stft_result = fft::stft(
signal,
n_fft,
hop_length,
win_length,
window,
true,
"reflect",
false,
false);
CHECK_EQ(stft_result.shape(1), n_fft);
}
}
TEST_CASE("test fftshift and ifftshift") {
// Test 1D array with even length
auto x = arange(8);