Compare commits

..

8 Commits

Author SHA1 Message Date
Angelos Katharopoulos a22d0bf273 Add stricter condition to matrix sdpa 2025-08-06 19:51:14 -07:00
Jagrit Digani 99d8de8445 Fix cudnn routing 2025-08-06 15:05:58 -07:00
Jagrit Digani c66b76a8c8 Update routing 2025-08-06 15:01:15 -07:00
Jagrit Digani f81edd184f Complete 2 pass sdpav 2025-08-06 13:57:40 -07:00
Jagrit Digani 7f8ba2a003 [WIP] 2 pass sdpav 2025-08-06 09:56:39 -07:00
Jagrit Digani c28249b81a Add more nvtx range for debug 2025-08-06 09:56:39 -07:00
Jagrit Digani e74bcdc5e3 Add sdpa file 2025-08-06 09:56:39 -07:00
Jagrit Digani d8ed6c1aa3 Add base cudnn attention support 2025-08-06 09:56:39 -07:00
100 changed files with 1214 additions and 2102 deletions

View File

@@ -1,5 +1,4 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

View File

@@ -18,7 +18,6 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For
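
A minimal end-to-end sketch of the updated pattern in this hunk. The ``nn.Linear`` model and the ``optimizer.init`` call are illustrative assumptions; the save/load lines mirror the documentation above, where ``tree_flatten`` now yields a list of ``(key, value)`` pairs instead of a dict.

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(4, 4)
    optimizer = optim.Adam(learning_rate=1e-2)
    optimizer.init(model.trainable_parameters())

    # tree_flatten returns a list of (key, array) pairs, so wrap it in
    # dict() before writing the safetensors file.
    state = tree_flatten(optimizer.state)
    mx.save_safetensors("optimizer.safetensors", dict(state))

    # Later: recreate the optimizer and restore its state. mx.load
    # returns a flat dict, so convert back to a list of pairs first.
    optimizer = optim.Adam(learning_rate=1e-2)
    optimizer.state = tree_unflatten(
        list(mx.load("optimizer.safetensors").items())
    )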

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example were missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
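
A short sketch of the point above, assuming an ``nn.Linear`` model purely for illustration: evaluate the enclosed parameters before exporting so the random initialization graph is not captured.

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)

    # Materialize the lazily initialized parameters first; otherwise the
    # exported graph also contains the random initialization.
    mx.eval(model.parameters())

    def call(x):
        return model(x)

    mx.export_function("model.mlxfn", call, mx.zeros(4))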
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
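
Putting the hunk above into a self-contained form, a sketch with the surrounding boilerplate added (the ``nn.Linear`` model is an assumption for illustration):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(4, 4)
    mx.eval(model.parameters())

    def call(x, **params):
        # Set the model's parameters to the input parameters
        model.update(tree_unflatten(list(params.items())))
        return model(x)

    params = dict(tree_flatten(model.parameters()))
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)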
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example, the function constant data (i.e. ``constant``) is only
saved once.
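
The body of the ``with`` block is elided by the hunk above; a possible sketch is shown below. The two ``exporter(...)`` calls and the printed value are illustrative assumptions, not the exact documentation content.

.. code-block:: python

    import mlx.core as mx

    def fun(x, y=None):
        constant = mx.array(3.0)
        if y is not None:
            x += y
        return x + constant

    # Export two signatures of the same function into a single file.
    # The enclosed ``constant`` is stored only once.
    with mx.exporter("fun.mlxfn", fun) as exporter:
        exporter(mx.array(1.0))                   # fun(x)
        exporter(mx.array(1.0), y=mx.array(2.0))  # fun(x, y)

    imported_fun = mx.import_function("fun.mlxfn")
    (out,) = imported_fun(mx.array(1.0))
    print(out)  # expected: array(4, dtype=float32)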
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
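
For context, a hedged sketch of applying transformations to an imported function. It assumes ``fun.mlxfn`` was exported from a function like ``fun(x) = x + 3`` as in the sketch above, and spells out the ``compiled_fun`` assignment that the excerpt elides.

.. code-block:: python

    import mlx.core as mx

    imported_fun = mx.import_function("fun.mlxfn")
    x = mx.array(1.0)

    # Imported functions return a tuple of outputs, so take the first
    # output inside the wrapper before differentiating.
    dfdx = mx.grad(lambda x: imported_fun(x)[0])
    print(dfdx(x))  # d/dx (x + 3) == 1

    # Compiling works the same way as for regular Python functions.
    compiled_fun = mx.compile(imported_fun)
    print(compiled_fun(x)[0])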
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -196,6 +196,9 @@ void shared_buffer_reshape(
const Strides& out_strides,
array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));

View File

@@ -157,12 +157,10 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
@@ -177,8 +175,8 @@ inline void build_kernel(
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
@@ -188,8 +186,10 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output size
if (contiguous) {
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -288,8 +288,17 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
@@ -297,6 +306,9 @@ void Compiled::eval_cpu(
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
}
// Get the kernel name from the lib
@@ -331,20 +343,16 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (contiguous) {
if (!contiguous) {
args.push_back((void*)shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
shape = std::move(shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core

View File

@@ -491,27 +491,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:

View File

@@ -8,6 +8,7 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -44,20 +45,18 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -149,7 +148,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_TAG v1.12.1
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -99,89 +99,39 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
}
}
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -259,61 +209,39 @@ void binary_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
auto [num_blocks, block_dims] =
get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
rest,
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
}
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::binary_g<Op, InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
rest,
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@@ -376,4 +304,54 @@ void binary_op_gpu(
binary_op_gpu<cu::func>(inputs, out, name(), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(GreaterEqual)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Less)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LessEqual)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogAddExp)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalAnd)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalOr)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Maximum)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Minimum)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Multiply)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(NotEqual)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Power)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Remainder)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Subtract)
} // namespace mlx::core

View File

@@ -127,99 +127,45 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
}
}
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size_rest,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -279,64 +225,42 @@ void binary_two_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out_a.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
rest,
out_a.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
}
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::binary_two_g<Op, InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
rest,
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),

View File

@@ -7,6 +7,9 @@
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
@@ -333,42 +336,6 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
}
}
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
const array& x,
int groups,
int group_dim,
int axis1,
int axis2,
Stream s) {
if (groups == 1) {
return swapaxes_in_eval(x, axis1, axis2);
}
int ndim = x.ndim();
if (group_dim < 0) {
group_dim += ndim;
}
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
if (group_dim <= axis1) {
axis1 += 1;
}
if (group_dim <= axis2) {
axis2 += 1;
}
auto shape = x.shape();
shape.insert(shape.begin() + group_dim, groups);
shape[group_dim + 1] = shape[group_dim + 1] / groups;
array x_trans = reshape_in_eval(x, std::move(shape), s);
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
return x_trans;
}
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
@@ -378,14 +345,13 @@ std::tuple<array, array, array> prepare_args(
array in,
array wt,
array out,
int groups,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = group_transpose(wt, groups, 0, 0, -1, s);
wt = swapaxes_in_eval(wt, 0, -1);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = group_transpose(in, groups, -1, 0, -1, s);
in = swapaxes_in_eval(in, 0, -1);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped.
@@ -491,8 +457,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
@@ -525,7 +490,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
prepare_args(encoder, try_backend, in, wt, out, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,

View File

@@ -10,80 +10,37 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_nd(
const In* in,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
auto shape_x = shape[NDIM - 1];
auto in_stride_x = strides_in[NDIM - 1];
auto out_stride_x = strides_out[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index_rest * shape_x,
shape.data(),
strides_in.data(),
strides_out.data());
auto in_vec =
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}
template <typename In, typename Out, typename IdxT, int N_READS>
template <typename In, typename Out, typename IdxT>
__global__ void copy_gg(
const In* in,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
auto shape_x = shape[ndim - 1];
auto in_stride_x = strides_in[ndim - 1];
auto out_stride_x = strides_out[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [idx_in, idx_out] = elem_to_loc(
index_rest * shape_x,
shape.data(),
strides_in.data(),
strides_out.data(),
ndim);
auto in_vec =
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}
} // namespace cu
@@ -112,52 +69,33 @@ void copy_general(
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = data_size / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
}
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
rest,
data_size,
const_param<ndim_constant()>(shape),
const_param<ndim_constant()>(strides_in),
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
}
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::copy_gg<InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
rest,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@@ -10,67 +10,33 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_g_nd(
const In* in,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
auto shape_x = shape[NDIM - 1];
auto stride_x = strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename In, typename Out, typename IdxT, int N_READS>
template <typename In, typename Out, typename IdxT>
__global__ void copy_g(
const In* in,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
const __grid_constant__ Strides strides_in,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
} // namespace cu
@@ -95,49 +61,30 @@ void copy_general_input(
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
rest,
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::copy_g<InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
rest,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);

View File

@@ -146,23 +146,6 @@ inline __device__ void store_vector(
}
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size,
int64_t stride) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[stride * (offset * N + i)] = vec[i];
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

View File

@@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core {
namespace mlx::core::cu {
void CublasGemm::run_batched(
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void CublasGemm::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
execute(
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void CublasGemm::run_batched(
}
}
void CublasGemm::run_batched(
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void CublasGemm::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
execute(
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void CublasGemm::run_batched(
}
}
} // namespace mlx::core
} // namespace mlx::core::cu

View File

@@ -0,0 +1,208 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -7,12 +7,10 @@
#include <fmt/format.h>
namespace mlx::core {
namespace {
namespace mlx::core::cu {
struct CublasPreference {
CublasPreference(cu::Device& device) {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -35,7 +33,7 @@ struct CublasPreference {
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
@@ -54,7 +52,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
@@ -72,7 +70,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
@@ -104,10 +102,8 @@ cublasLtMatrixLayout_t create_matrix_layout(
return desc;
}
} // namespace
CublasGemm::CublasGemm(
cu::Device& device,
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -159,8 +155,8 @@ CublasGemm::CublasGemm(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
CublasGemm::CublasGemm(
cu::Device& device,
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -175,7 +171,7 @@ CublasGemm::CublasGemm(
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: CublasGemm(
: Matmul(
device,
dtype,
a_transposed,
@@ -194,7 +190,7 @@ CublasGemm::CublasGemm(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
CublasGemm::~CublasGemm() {
Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -202,73 +198,7 @@ CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
void Matmul::run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -326,4 +256,29 @@ void CublasGemm::execute(
encoder.stream()));
}
} // namespace mlx::core
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core {
class CublasGemm {
namespace mlx::core::cu {
class Matmul {
public:
CublasGemm(
cu::Device& device,
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -25,8 +25,8 @@ class CublasGemm {
int64_t a_batch_stride,
int64_t b_batch_stride);
CublasGemm(
cu::Device& device,
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -42,54 +42,41 @@ class CublasGemm {
int64_t b_batch_stride,
int64_t c_batch_stride);
~CublasGemm();
~Matmul();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
void run(
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
void execute(
void run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -110,4 +97,4 @@ class CublasGemm {
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core
} // namespace mlx::core::cu

View File

@@ -1,327 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
} // namespace
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
{batch_count * 3},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_mm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{batch_count * 4},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core
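For reference, the batched GEMM path above packs every operand pointer into one device buffer: batch_count pointers for A, then B, then the output (plus C on the addmm path), which is the layout CUBLASLT_BATCH_MODE_POINTER_ARRAY expects. Below is a minimal host-side sketch of that same layout; the helper names (batch_offsets, make_pointer_array) are illustrative only, and the real code computes these offsets on the device via elem_to_loc.

// Sketch only: reproduces the [A..., B..., Out...] pointer-array layout on the host.
#include <cstdint>
#include <utility>
#include <vector>

// Map a flat batch index to element offsets for A and B given the batch
// shape and per-operand batch strides (innermost dimension last).
static std::pair<int64_t, int64_t> batch_offsets(
    int64_t index,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& a_strides,
    const std::vector<int64_t>& b_strides) {
  int64_t a = 0;
  int64_t b = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    int64_t pos = index % shape[i];
    index /= shape[i];
    a += pos * a_strides[i];
    b += pos * b_strides[i];
  }
  return {a, b};
}

// Build the pointer array consumed by cublasLt in pointer-array batch mode.
std::vector<void*> make_pointer_array(
    char* a_start,
    char* b_start,
    char* out_start,
    int item_size,
    int64_t out_batch_stride, // M * N elements per output matrix
    int batch_count,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& a_strides,
    const std::vector<int64_t>& b_strides) {
  std::vector<void*> pointers(3 * batch_count);
  for (int i = 0; i < batch_count; ++i) {
    auto [a_off, b_off] = batch_offsets(i, shape, a_strides, b_strides);
    pointers[i] = a_start + item_size * a_off;
    pointers[i + batch_count] = b_start + item_size * b_off;
    pointers[i + 2 * batch_count] =
        out_start + item_size * i * out_batch_stride;
  }
  return pointers;
}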

View File

@@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::Matmul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -111,7 +111,14 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -179,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::Matmul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -195,7 +202,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
gemm.run(
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
encoder,
out,
a,

View File

@@ -8,13 +8,19 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace fe = cudnn_frontend;
namespace mlx::core {
namespace cu {
@@ -639,6 +645,294 @@ void sdpa_vector_fallback(
}
}
struct SDPACacheKey {
int device_id;
fe::DataType_t cudnn_type;
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
bool generate_stats;
bool causal_mask;
};
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
cache(
/* capacity */ 128);
return cache;
}
#define Q_UID 1
#define K_UID 2
#define V_UID 3
#define O_UID 4
#define STATS_UID 5
std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
cu::CommandEncoder& encoder,
const SDPACacheKey& cache_key) {
// Check if graph has already been fully built
if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
return it->second;
}
// Set up new graph
auto graph = std::make_shared<fe::graph::Graph>();
graph->set_io_data_type(cache_key.cudnn_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_uid(Q_UID)
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
.set_stride(
{cache_key.Q_strides[0],
cache_key.Q_strides[1],
cache_key.Q_strides[2],
1}));
int h_kv = cache_key.H / cache_key.gqa_factor;
auto K =
graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_uid(K_UID)
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
.set_stride(
{cache_key.K_strides[0],
cache_key.K_strides[1],
cache_key.K_strides[2],
1}));
auto V =
graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_uid(V_UID)
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
.set_stride(
{cache_key.V_strides[0],
cache_key.V_strides[1],
cache_key.V_strides[2],
1}));
auto sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_is_inference(!cache_key.generate_stats)
.set_attn_scale(cache_key.scale);
if (cache_key.causal_mask && cache_key.qL > 1) {
sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
.set_diagonal_band_right_bound(0);
}
auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
O->set_output(true)
.set_uid(O_UID)
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
.set_stride(
{cache_key.O_strides[0],
cache_key.O_strides[1],
cache_key.O_strides[2],
1});
if (cache_key.generate_stats) {
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_uid(STATS_UID);
}
// Build and Validate cudnn graph
auto handle = encoder.device().cudnn_handle();
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
if (cudnnGetVersion() < 90600) {
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
if (!build_status.is_good()) {
throw std::runtime_error(
"Unable to build cudnn graph for attention."
" Failed with message: " +
build_status.get_message());
}
} else {
auto val_status = graph->validate();
auto op_status = graph->build_operation_graph(handle);
auto plan_status =
graph->create_execution_plans({cudnn_frontend::HeurMode_t::A});
if (!plan_status.is_good()) {
throw std::runtime_error(
"Unable to create exec plan for cudnn attention."
" Failed with message: " +
plan_status.get_message());
}
graph->select_behavior_notes(
{cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
auto support_status = graph->check_support(handle);
if (!support_status.is_good()) {
throw std::runtime_error(
"No cuda graph support for cudnn attention."
" Failed with message: " +
support_status.get_message());
}
auto build_status = graph->build_plans(handle);
if (!build_status.is_good()) {
throw std::runtime_error(
"Unable to build cudnn graph for attention."
" Failed with message: " +
build_status.get_message());
}
}
auto [it, _] = sdpa_cache().emplace(cache_key, graph);
return it->second;
}
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
}
}
void sdpa_cudnn(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
auto cudnn_type = dtype_to_cudnn_type(q.dtype());
int B = q.shape(0);
int H = q.shape(1);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
int qL = q.shape(2);
int kL = k.shape(2);
SDPACacheKey cache_key{
/* int device_id = */ encoder.device().cuda_device(),
/* fe::DataType_t cudnn_type = */ cudnn_type,
/* int B = */ B,
/* int H = */ H,
/* int D = */ D,
/* int qL = */ qL,
/* int kL = */ kL,
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)},
/* bool generate_stats = */ false,
/* bool causal_mask = */ do_causal_};
auto graph = get_sdpa_forward_graph(encoder, cache_key);
int64_t workspace_size = 0;
auto workspace_status = graph->get_workspace_size(workspace_size);
if (!workspace_status.is_good()) {
throw std::runtime_error("Unable to get workspace for cudnn attention.");
}
array workspace(
allocator::malloc(workspace_size), {int(workspace_size)}, uint8);
auto workspace_ptr = workspace.data<void>();
std::unordered_map<int64_t, void*> variant_pack = {
{Q_UID, const_cast<void*>(q.data<void>())},
{K_UID, const_cast<void*>(k.data<void>())},
{V_UID, const_cast<void*>(v.data<void>())},
{O_UID, o.data<void>()}};
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
if (cudnnGetVersion() < 90600) {
auto capture = encoder.capture_context();
auto exec_status = graph->execute(handle, variant_pack, workspace_ptr);
if (!exec_status.is_good()) {
capture.discard = true;
throw std::runtime_error(
"Unable to execute cudnn attention."
" Failed with message: " +
exec_status.get_message());
}
} else {
cudaGraph_t cu_graph;
cudaGraphCreate(&cu_graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
auto cu_graph_status = graph->populate_cuda_graph(
handle, variant_pack, workspace_ptr, cu_graph);
if (!cu_graph_status.is_good()) {
throw std::runtime_error(
"Unable to add cuda graph for cudnn attention."
" Failed with message: " +
cu_graph_status.get_message());
}
encoder.add_graph_node(cu_graph);
}
encoder.add_temporary(workspace);
}
} // namespace
namespace fast {
@@ -651,9 +945,6 @@ bool ScaledDotProductAttention::use_fallback(
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
@@ -669,7 +960,15 @@ bool ScaledDotProductAttention::use_fallback(
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
auto& cu_device = cu::device(s.device);
const bool supported_matrix_config = query_sequence_length > 4 &&
cu_device.compute_capability_major() >= 8 &&
query_sequence_length == key_sequence_length &&
(q.dtype() == float16 || q.dtype() == bfloat16);
const bool supported_config =
(supported_matrix_config || supported_vector_config);
return has_arr_mask || !supported_config;
}
@@ -703,6 +1002,10 @@ void ScaledDotProductAttention::eval_gpu(
}
};
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
@@ -756,7 +1059,7 @@ void ScaledDotProductAttention::eval_gpu(
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
@@ -770,9 +1073,35 @@ void ScaledDotProductAttention::eval_gpu(
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here
// Full attention mode
else {
throw std::runtime_error("Doesn't support matrix yet.");
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
int64_t str_oD = 1;
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
}
}
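The cudnn attention path above memoizes built graphs in an LRU cache keyed by the raw bytes of SDPACacheKey, so every field that influences the built graph (device id, dtype, shapes, strides, scale, causal masking) has to appear in the key, and the key has to stay trivially copyable. The sketch below shows the byte-keyed lookup idea with a stand-in key and a plain std::unordered_map; GraphKey and BytesCache are hypothetical names, not MLX's LRUBytesKeyCache.

#include <cstring>
#include <memory>
#include <string>
#include <unordered_map>

struct GraphKey { // stand-in for SDPACacheKey: a trivially copyable POD
  int B, H, D, qL, kL, gqa_factor;
  float scale;
  bool causal_mask;
};

template <typename Value>
class BytesCache {
 public:
  std::shared_ptr<Value> find(const GraphKey& key) const {
    auto it = map_.find(bytes(key));
    return it == map_.end() ? nullptr : it->second;
  }
  void insert(const GraphKey& key, std::shared_ptr<Value> value) {
    map_[bytes(key)] = std::move(value);
  }

 private:
  // Serialize the key byte-for-byte. Zero-initializing the key before use
  // keeps indeterminate padding bytes out of the lookup string.
  static std::string bytes(const GraphKey& key) {
    std::string s(sizeof(GraphKey), '\0');
    std::memcpy(s.data(), &key, sizeof(GraphKey));
    return s;
  }
  std::unordered_map<std::string, std::shared_ptr<Value>> map_;
};

Because strides are part of the key in the diff above, a layout change such as a transposed K produces a different graph instead of silently reusing a stale plan.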

View File

@@ -1,5 +1,6 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"

View File

@@ -39,98 +39,52 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
}
}
template <typename Op, typename T, typename IdxT, int NDIM, int N_READS>
template <typename Op, typename T, typename IdxT, int NDIM>
__global__ void ternary_g_nd(
const bool* a,
const T* b,
const T* c,
T* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
index,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
auto c_stride_x = c_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));
auto c_vec =
load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename T, typename IdxT, int N_READS>
template <typename Op, typename T, typename IdxT>
__global__ void ternary_g(
const bool* a,
const T* b,
const T* c,
T* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
const __grid_constant__ Strides c_strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc(
index,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data(),
ndim);
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
auto c_stride_x = c_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx, c_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));
auto c_vec =
load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
} // namespace cu
@@ -169,55 +123,36 @@ void ternary_op_gpu_inplace(
auto& b_strides = strides[1];
auto& c_strides = strides[2];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 4>;
}
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
rest,
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides),
const_param<dims_constant()>(c_strides));
});
} else {
auto kernel = cu::ternary_g<Op, DType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::ternary_g<Op, DType, IdxT, 4>;
}
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::ternary_g<Op, DType, IdxT>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
rest,
out.data_size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),

View File

@@ -37,36 +37,19 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size_rest,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]);
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -144,7 +127,8 @@ void unary_op_gpu_inplace(
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(OutType);
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
@@ -158,30 +142,18 @@ void unary_op_gpu_inplace(
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
auto [num_blocks, block_dims] = get_launch_args(out, large);
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
cu::unary_g<Op, InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
rest,
out.data_size(),
const_param(shape),
const_param(strides),
ndim);
shape.size());
}
});
} else {

View File

@@ -1,34 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu)

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Abs)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcCos)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcCosh)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcSin)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcSinh)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcTan)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcTanh)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(BitwiseInvert)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Ceil)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Conjugate)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Cos)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Cosh)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Erf)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ErfInv)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Exp)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Expm1)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Floor)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Imag)
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, name(), s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Log1p)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(LogicalNot)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Negative)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Real)
} // namespace mlx::core

View File

@@ -1,18 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sigmoid)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sign)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sin)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sinh)
} // namespace mlx::core

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sqrt::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Square)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Tan)
} // namespace mlx::core

View File

@@ -1,7 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Tanh)
} // namespace mlx::core

View File

@@ -1,215 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
std::is_same_v<Op, Sigmoid>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
rest,
const_param(shape),
const_param(strides),
ndim);
}
});
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \
}
} // namespace mlx::core

View File

@@ -52,70 +52,4 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
return arr_copy;
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
int ndim = x.ndim();
if (start_axis < 0) {
start_axis += ndim;
}
if (end_axis < 0) {
end_axis += ndim;
}
start_axis = std::max(0, start_axis);
end_axis = std::min(ndim - 1, end_axis);
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
}
array reshape_in_eval(const array& x, Shape shape, Stream s) {
array out(std::move(shape), x.dtype(), nullptr, {});
reshape_gpu(x, out, s);
return out;
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -46,12 +46,4 @@ void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
// Copy data from |in| and reshape to |out|'s shape.
void reshape_gpu(const array& in, array& out, Stream s);
// Like the normal ops but safe to call in eval_gpu.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
array reshape_in_eval(const array& x, Shape shape, Stream s);
array swapaxes_in_eval(const array& x, int axis1, int axis2);
} // namespace mlx::core

View File

@@ -20,6 +20,29 @@
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
@@ -101,7 +124,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape_gpu(inputs[0], out, stream());
reshape(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -127,7 +150,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape_gpu(inputs[0], out, stream());
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
@@ -201,7 +224,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape_gpu(inputs[0], out, stream());
reshape(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -60,12 +60,22 @@ struct CommandEncoder {
enc_->updateFence(fence);
}
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, int idx) {
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
// TODO: Code is duplicated but they should be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
@@ -94,7 +104,7 @@ struct CommandEncoder {
};
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*>& outputs() {
std::unordered_set<const void*> outputs() {
return all_outputs_;
};

View File

@@ -166,6 +166,115 @@ instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
out += tid.z * params.out_strides[0];
in += tid.z * params.in_strides[0];
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for (int h = 0; h < params.wS[0]; ++h) {
for (int w = 0; w < params.wS[1]; ++w) {
for (int c = 0; c < params.C; ++c) {
// Local in
for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c]
: T(0);
}
// Accumulate
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
(out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -262,37 +262,36 @@ struct GEMVKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
@@ -545,32 +544,31 @@ struct GEMVTKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
v_coeff[tm] *= block_scale;
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}

View File

@@ -134,10 +134,6 @@ instantiate_and_or(and, And)
instantiate_and_or(or, Or)
#define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \

View File

@@ -45,9 +45,7 @@ struct ThreadSort {
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
if (ARG_SORT) {
thread_swap(idxs[j + 1], idxs[j]);
}
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
@@ -113,9 +111,7 @@ struct BlockMergeSort {
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
if (ARG_SORT) {
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
}
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);

View File

@@ -247,25 +247,15 @@ std::pair<Dtype, Dtype> remap_reduce_types(
const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) {
switch (in.dtype()) {
case uint8:
return {uint8, uint32};
case uint16:
return {uint16, uint32};
case uint32:
return {uint32, uint32};
case uint64:
return {uint64, uint64};
case int8:
switch (in.dtype().size()) {
case 1:
return {int8, int32};
case int16:
case 2:
return {int16, int32};
case int32:
case 4:
return {int32, int32};
case int64:
case 8:
return {int64, int64};
default:
throw std::runtime_error("Unsupported integer type");
}
}
if (in.dtype() == bool_) {

View File

@@ -2381,20 +2381,9 @@ array logsumexp(
throw std::invalid_argument(
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
}
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && reduce_last_dim) {
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
auto out_shape = a.shape();
out_shape.back() = 1;
@@ -3414,20 +3403,10 @@ array softmax(
throw std::invalid_argument(
"[softmax] Received non-empty axes for array with 0 dimensions.");
}
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && reduce_last_dim) {
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(
a.shape(),

View File

@@ -440,7 +440,6 @@ class SmallVector {
end_ = begin_;
}
private:
// Grows the backing store by a factor of two, and at least to {min_capacity}.
// TODO: Move to private after removing external code using this method.
MLX_NOINLINE void grow(size_t min_capacity = 0) {
@@ -470,6 +469,7 @@ class SmallVector {
end_of_storage_ = new_storage + new_capacity;
}
private:
MLX_NOINLINE void free_storage() {
std::destroy_n(begin_, end_ - begin_);
if (is_big()) {
@@ -519,18 +519,6 @@ class SmallVector {
std::is_trivially_destructible<T>::value;
};
template <typename>
struct is_vector : std::false_type {};
template <typename T, size_t Size, typename Allocator>
struct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};
template <typename T, typename Allocator>
struct is_vector<std::vector<T, Allocator>> : std::true_type {};
template <typename Vec>
inline constexpr bool is_vector_v = is_vector<Vec>::value;
#undef MLX_HAS_BUILTIN
#undef MLX_HAS_ATTRIBUTE
#undef MLX_LIKELY

View File

@@ -259,6 +259,43 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
// TODO: Code is duplicated but they should be deleted soon.
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
namespace env {
int get_var(const char* name, int default_value) {

View File

@@ -100,6 +100,10 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}
@@ -110,19 +114,6 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
os << "(";
for (auto it = v.begin(); it != v.end(); ++it) {
os << *it;
if (it != std::prev(v.end())) {
os << ",";
}
}
os << ")";
return os;
}
inline bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}

View File

@@ -3,8 +3,8 @@
#pragma once
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 28
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_MINOR 27
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -2,6 +2,6 @@
requires = [
"setuptools>=80",
"nanobind==2.4.0",
"cmake>=3.25,<4.1",
"cmake>=3.25",
]
build-backend = "setuptools.build_meta"

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict:
new_weights = dict(weights)
curr_weights = tree_flatten(self.parameters(), destination={})
curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras)
extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors`
"""
params_dict = tree_flatten(self.parameters(), destination={})
params_dict = dict(tree_flatten(self.parameters()))
if file.endswith(".npz"):
mx.savez(file, **params_dict)
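A minimal sketch of the updated call pattern, assuming mlx and mlx.nn are importable; tree_flatten now returns a list of (key, array) tuples, so callers build the dict explicitly before saving:

# Sketch only: assumes mlx.core, mlx.nn, and mlx.utils are available.
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Linear(4, 4)

# tree_flatten returns [(key, array), ...]; convert to a dict for the savers.
params_dict = dict(tree_flatten(model.parameters()))
mx.save_safetensors("weights.safetensors", params_dict)

# The Module helpers wrap the same conversion internally.
model.save_weights("weights.safetensors")
model.load_weights("weights.safetensors")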

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple
def tree_map(
@@ -114,11 +114,8 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -131,12 +128,9 @@ def tree_flatten(
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], prefix=".hello"))
print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note::
Dictionaries should have keys that are valid Python identifiers.
@@ -146,50 +140,26 @@ def tree_flatten(
always discarded.
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns:
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
the Python tree.
List[Tuple[str, Any]]: The flat representation of the Python tree.
"""
if destination is None:
destination = []
flat_tree = []
# Create the function to update the destination. We are taking advantage of
# the fact that list.extend and dict.update have the same API to simplify
# the code a bit.
if isinstance(destination, list):
_add_to_destination = destination.extend
elif isinstance(destination, dict):
_add_to_destination = destination.update
else:
raise ValueError("Destination should be either a list or a dictionary or None")
if is_leaf is None or not is_leaf(tree):
if isinstance(tree, (list, tuple)):
for i, t in enumerate(tree):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return flat_tree
if isinstance(tree, dict):
for k, t in tree.items():
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return flat_tree
# Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
return [(prefix[1:], tree)]
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -200,34 +170,31 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
print(d)
# {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args:
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A Python tree.
"""
items = tree.items() if isinstance(tree, dict) else tree
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
# Special case when we have just one element in the tree ie not a tree
if len(items) == 1:
key, value = next(iter(items))
if key == "":
return value
try:
int(tree[0][0].split(".", maxsplit=1)[0])
is_list = True
except ValueError:
is_list = False
# collect children
children = defaultdict(list)
for key, value in items:
for key, value in tree:
current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value))
# Assume they are a list and fail to dict if the keys are not all integers
try:
# recursively map them to the original container
if is_list:
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
@@ -235,7 +202,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k]))
return l
except ValueError:
else:
return {k: tree_unflatten(v) for k, v in children.items()}
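To make the slimmed-down utilities concrete, here is a short round trip under the assumption that mlx.utils is importable; tree_flatten always produces a list of tuples now, and tree_unflatten expects that list rather than a dict:

# Sketch of the list-only API; assumes mlx is installed.
from mlx.utils import tree_flatten, tree_unflatten

tree = {"layers": [{"w": 1}, {"w": 2}]}

flat = tree_flatten(tree)
# [("layers.0.w", 1), ("layers.1.w", 2)]

# Where destination={} was passed before, convert the list explicitly.
flat_dict = dict(flat)

# tree_unflatten takes the list of (key, value) tuples, not a dict.
restored = tree_unflatten(flat)
# {"layers": [{"w": 1}, {"w": 2}]}

# Flat dicts (e.g. from mx.load) are converted back to items first.
restored_from_dict = tree_unflatten(list(flat_dict.items()))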

View File

@@ -17,6 +17,7 @@ cuda_skip = {
"TestConv.test_1d_conv_with_2d",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule()
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))

View File

@@ -155,19 +155,6 @@ TEST_CASE("test gpu reduce") {
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
}
// sum and prod overflow
{
auto a = full({256, 2, 2}, 1u, uint8);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
a = full({65535, 2, 2}, 1u, uint16);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
}
}
TEST_CASE("test gpu reduce with axes") {
// reducing only some axes and irregular layouts
{
array a(1.0f);

View File

@@ -915,23 +915,6 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
}
// Test unsigned sum
{
const int num_elems = 1000;
auto x = astype(full({num_elems}, 255), uint8);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
x = astype(full({num_elems}, 65535), uint16);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
x = full({3, 3, 3}, 10000, uint32);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
x = full({3, 3, 3}, 10000, uint64);
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
}
// Test prod
{
auto x = array({});
@@ -964,21 +947,6 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
}
// Test unsigned prod
{
auto x = array({255, 255}, {2}, uint8);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
x = array({65535, 2}, {2}, uint16);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
x = array({100000, 2}, {2}, uint32);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
x = array({100000, 2}, {2}, uint64);
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
}
// Test all
{
auto x = array({});