Compare commits

..

8 Commits

Author                   SHA1        Message                                                         Date
Abe Leininger            fce53b61d6  Fix reduce sum/prod overflow (#2477)                            2025-08-12 00:05:33 -07:00
Angelos Katharopoulos    8ae4a76308  Use CMake <4.1 to avoid the nvpl error (#2489)                  2025-08-12 00:03:42 -07:00
Cheng                    7fde1b6a1e  Fix logsumexp/softmax not fused for some cases (#2474)          2025-08-08 14:07:17 -07:00
Cheng                    aa7b47481a  [CUDA] Optimize set_mm_device_pointers for small ndim (#2473)   2025-08-08 15:23:30 +09:00
Awni Hannun              56be773610  version (#2470)                                                 2025-08-07 00:36:04 -07:00
Jagrit Digani            a9bdd67baa  Add CUDA sdpa vector (#2468)                                    2025-08-06 21:40:26 -07:00
Angelos Katharopoulos    f2adb5638d  Fix typo in metal command encoder (#2471)                       2025-08-06 16:58:23 -07:00
Luca Vivona              728d4db582  Support destination arg in tree flatten/unflatten (#2450)       2025-08-06 15:34:59 -07:00
16 changed files with 348 additions and 445 deletions

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
     optimizer.update(model, grads)

     # Save the state
-    state = tree_flatten(optimizer.state)
-    mx.save_safetensors("optimizer.safetensors", dict(state))
+    state = tree_flatten(optimizer.state, destination={})
+    mx.save_safetensors("optimizer.safetensors", state)

     # Later on, for example when loading from a checkpoint,
     # recreate the optimizer and load the state
     optimizer = optim.Adam(learning_rate=1e-2)
-    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+    state = tree_unflatten(mx.load("optimizer.safetensors"))
     optimizer.state = state

 Note, not every optimizer configuation parameter is saved in the state. For
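For context, this is the full round trip the updated documentation describes; a minimal sketch based on the snippet above, assuming an optimizer has already been stepped:

    import mlx.core as mx
    import mlx.optimizers as optim
    from mlx.utils import tree_flatten, tree_unflatten

    optimizer = optim.Adam(learning_rate=1e-2)
    # ... after some optimizer.update(...) steps, checkpoint the state ...
    state = tree_flatten(optimizer.state, destination={})  # flat dict keyed by dotted paths
    mx.save_safetensors("optimizer.safetensors", state)

    # mx.load returns a flat dict, which tree_unflatten now accepts directly
    optimizer = optim.Adam(learning_rate=1e-2)
    optimizer.state = tree_unflatten(mx.load("optimizer.safetensors"))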

View File

@@ -7,17 +7,17 @@ Exporting Functions
 MLX has an API to export and import functions to and from a file. This lets you
 run computations written in one MLX front-end (e.g. Python) in another MLX
 front-end (e.g. C++).

 This guide walks through the basics of the MLX export API with some examples.
 To see the full list of functions check-out the :ref:`API documentation
 <export>`.

 Basics of Exporting
 -------------------

 Let's start with a simple example:

 .. code-block:: python

    def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
    x = mx.array(1.0)
    y = mx.array(1.0)

    # Both arguments to fun are positional
    mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
    For enclosed arrays inside an exported function, be extra careful to ensure
    they are evaluated. The computation graph that gets exported will include
    the computation that produces enclosed inputs.

    If the above example was missing ``mx.eval(model.parameters()``, the
    exported function would include the random initialization of the
    :obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
        # Set the model's parameters to the input parameters
        model.update(tree_unflatten(list(params.items())))
        return model(x)

-   params = dict(tree_flatten(model.parameters()))
+   params = tree_flatten(model.parameters(), destination={})
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
    # Ok
    out, = imported_abs(mx.array(-1.0))

    # Also ok
    out, = imported_abs(mx.array([-1.0, -2.0]))
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
    def fun(x, y=None):
        constant = mx.array(3.0)
        if y is not None:
            x += y
        return x + constant

    with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
        print(out)

 In the above example the function constant data, (i.e. ``constant``), is only
 saved once.

 Transformations with Imported Functions
 ---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
    # Prints: array(1, dtype=float32)
    print(dfdx(x))

    # Compile the imported function
    mx.compile(imported_fun)

    # Prints: array(0, dtype=float32)
    print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
    // Prints: array(2, dtype=float32)
    std::cout << outputs[0] << std::endl;

 Imported functions can be transformed in C++ just like in Python. Use
 ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
 mx::array>`` for keyword arguments when calling imported functions in C++.
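Putting the ``destination={}`` change into the export workflow, here is a minimal sketch of the ``call`` wrapper pattern the docs describe; the ``nn.Linear`` layer is just an illustrative stand-in for any small model:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(4, 4)
    mx.eval(model.parameters())  # ensure enclosed arrays are evaluated before export

    def call(x, **params):
        # Set the model's parameters to the input parameters
        model.update(tree_unflatten(list(params.items())))
        return model(x)

    params = tree_flatten(model.parameters(), destination={})
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)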

View File

@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       switch (in.dtype()) {
         case bool_:
         case uint8:
+          reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint16:
+          reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+          reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+          reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
+          break;
         case int8:
           reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
           break;
         case int16:
-        case uint16:
           reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
           break;
         case int32:
-        case uint32:
           reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
           break;
         case int64:
-        case uint64:
           reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
           break;
         case float16:
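A rough sketch of the user-visible effect from the Python side, assuming the widened accumulators above (the exact output dtype follows the dispatch table in this change):

    import mlx.core as mx

    x = mx.full((1000,), 255, dtype=mx.uint8)
    # Unsigned inputs previously went through signed dispatch of the same width,
    # which could overflow; with a wide unsigned accumulator the sum is
    # 255 * 1000 = 255000.
    print(mx.sum(x))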

View File

@@ -10,7 +10,34 @@ namespace mlx::core::cu {

 namespace cg = cooperative_groups;

-__global__ void set_mm_device_pointers(
+template <int NDIM>
+__global__ void set_mm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_mm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
@@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers(
       out_start + item_size * index * batch_stride;
 }

-__global__ void set_addmm_device_pointers(
+template <int NDIM>
+__global__ void set_addmm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* c_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      c_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
+  pointers[index + 3 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_addmm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
@@ -89,37 +147,62 @@ void Matmul::run_batched(
     const mlx::core::Shape& batch_shape,
     const mlx::core::Strides& a_batch_strides,
     const mlx::core::Strides& b_batch_strides) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(out_desc_, batch_count);

   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
-      {static_cast<int>(batch_count * 3)},
+      allocator::malloc(batch_count * sizeof(void*) * 3),
+      {batch_count * 3},
       uint64);

   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);
-  encoder.add_kernel_node(
-      cu::set_mm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
+
+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+      encoder.add_kernel_node(
+          cu::set_mm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
+          0,
+          pointers.data<int8_t*>(),
+          a.data<int8_t>(),
+          b.data<int8_t>(),
+          out.data<int8_t>(),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_mm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
+        const_param(batch_shape),
+        const_param(a_batch_strides),
+        const_param(b_batch_strides),
+        batch_stride,
+        ndim,
+        batch_count);
+  }

   // Run matmul
   encoder.set_input_array(pointers);
@@ -150,7 +233,7 @@ void Matmul::run_batched(
     const mlx::core::Strides& c_batch_strides,
     float alpha,
     float beta) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(c_desc_, batch_count);
@@ -159,30 +242,58 @@ void Matmul::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
       allocator::malloc(batch_count * sizeof(uint64_t) * 4),
-      {static_cast<int>(batch_count * 4)},
+      {batch_count * 4},
       uint64);

   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);
-  encoder.add_kernel_node(
-      cu::set_addmm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      c.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      const_param(c_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
+
+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+      encoder.add_kernel_node(
+          cu::set_addmm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
+          0,
+          pointers.data<int8_t*>(),
+          a.data<int8_t>(),
+          b.data<int8_t>(),
+          c.data<int8_t>(),
+          out.data<int8_t>(),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          const_param<ndim_constant()>(c_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_addmm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        c.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
+        const_param(batch_shape),
+        const_param(a_batch_strides),
+        const_param(b_batch_strides),
+        const_param(c_batch_strides),
+        batch_stride,
+        ndim,
+        batch_count);
+  }

   // Run matmul
   encoder.set_input_array(pointers);
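As a rough Python illustration of what the pointer-setup kernels compute per batch element: a linear batch index is unraveled against the batch shape and each input's batch strides, exactly the ``elem_to_loc``-style indexing the kernels rely on. The function and values below are illustrative, not MLX API:

    def batch_offsets(index, batch_shape, a_strides, b_strides):
        """Unravel a linear batch index into per-input element offsets."""
        a_off = b_off = 0
        for dim, a_s, b_s in zip(
            reversed(batch_shape), reversed(a_strides), reversed(b_strides)
        ):
            pos = index % dim
            index //= dim
            a_off += pos * a_s
            b_off += pos * b_s
        return a_off, b_off

    # e.g. batch shape (2, 3) with a broadcast (stride 0) leading dim on the first input
    print(batch_offsets(5, (2, 3), (0, 16), (48, 16)))  # (32, 80)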

View File

@@ -8,19 +8,13 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"

-// cudnn_frontend.h redefines this macro.
-#undef CHECK_CUDA_ERROR
-#include <cudnn_frontend.h>
-#include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>

 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>

-namespace fe = cudnn_frontend;
-
 namespace mlx::core {

 namespace cu {
@@ -645,294 +639,6 @@ void sdpa_vector_fallback(
  }
}
struct SDPACacheKey {
int device_id;
fe::DataType_t cudnn_type;
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
bool generate_stats;
bool causal_mask;
};
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
cache(
/* capacity */ 128);
return cache;
}
#define Q_UID 1
#define K_UID 2
#define V_UID 3
#define O_UID 4
#define STATS_UID 5
std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
cu::CommandEncoder& encoder,
const SDPACacheKey& cache_key) {
// Check if graph has already been fully built
if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
return it->second;
}
// Set up new graph
auto graph = std::make_shared<fe::graph::Graph>();
graph->set_io_data_type(cache_key.cudnn_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_uid(Q_UID)
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
.set_stride(
{cache_key.Q_strides[0],
cache_key.Q_strides[1],
cache_key.Q_strides[2],
1}));
int h_kv = cache_key.H / cache_key.gqa_factor;
auto K =
graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_uid(K_UID)
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
.set_stride(
{cache_key.K_strides[0],
cache_key.K_strides[1],
cache_key.V_strides[2],
1}));
auto V =
graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_uid(V_UID)
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
.set_stride(
{cache_key.V_strides[0],
cache_key.V_strides[1],
cache_key.V_strides[2],
1}));
auto sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_is_inference(!cache_key.generate_stats)
.set_attn_scale(cache_key.scale);
if (cache_key.causal_mask && cache_key.qL > 1) {
sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
.set_diagonal_band_right_bound(0);
}
auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
O->set_output(true)
.set_uid(O_UID)
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
.set_stride(
{cache_key.O_strides[0],
cache_key.O_strides[1],
cache_key.O_strides[2],
1});
if (cache_key.generate_stats) {
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_uid(STATS_UID);
}
// Build and Validate cudnn graph
auto handle = encoder.device().cudnn_handle();
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
if (cudnnGetVersion() < 90600) {
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
if (!build_status.is_good()) {
throw std::runtime_error(
"Unable to build cudnn graph for attention."
" Failed with message: " +
build_status.get_message());
}
} else {
auto val_status = graph->validate();
auto op_status = graph->build_operation_graph(handle);
auto plan_stauts =
graph->create_execution_plans({cudnn_frontend::HeurMode_t::A});
if (!plan_stauts.is_good()) {
throw std::runtime_error(
"Unable to create exec plan for cudnn attention."
" Failed with message: " +
plan_stauts.get_message());
}
graph->select_behavior_notes(
{cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
auto support_status = graph->check_support(handle);
if (!support_status.is_good()) {
throw std::runtime_error(
"No cuda graph support for cudnn attention."
" Failed with message: " +
support_status.get_message());
}
auto build_status = graph->build_plans(handle);
if (!build_status.is_good()) {
throw std::runtime_error(
"Unable to build cudnn graph for attention."
" Failed with message: " +
build_status.get_message());
}
}
auto [it, _] = sdpa_cache().emplace(cache_key, graph);
return it->second;
}
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
}
}
void sdpa_cudnn(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
auto cudnn_type = dtype_to_cudnn_type(q.dtype());
int B = q.shape(0);
int H = q.shape(1);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
int qL = q.shape(2);
int kL = k.shape(2);
SDPACacheKey cache_key{
/* int device_id = */ encoder.device().cuda_device(),
/* fe::DataType_t cudnn_type = */ cudnn_type,
/* int B = */ B,
/* int H = */ H,
/* int D = */ D,
/* int qL = */ qL,
/* int kL = */ kL,
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)},
/* bool generate_stats = */ false,
/* bool causal_mask = */ do_causal_};
auto graph = get_sdpa_forward_graph(encoder, cache_key);
int64_t workspace_size = 0;
auto workspace_status = graph->get_workspace_size(workspace_size);
if (!workspace_status.is_good()) {
throw std::runtime_error("Unable to get workspace for cudnn attention.");
}
array workspace(
allocator::malloc(workspace_size), {int(workspace_size)}, uint8);
auto workspace_ptr = workspace.data<void>();
std::unordered_map<int64_t, void*> variant_pack = {
{Q_UID, const_cast<void*>(q.data<void>())},
{K_UID, const_cast<void*>(k.data<void>())},
{V_UID, const_cast<void*>(v.data<void>())},
{O_UID, o.data<void>()}};
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
if (cudnnGetVersion() < 90600) {
auto capture = encoder.capture_context();
auto exec_status = graph->execute(handle, variant_pack, workspace_ptr);
if (!exec_status.is_good()) {
capture.discard = true;
throw std::runtime_error(
"Unable to execute cudnn attention."
" Failed with message: " +
exec_status.get_message());
}
} else {
cudaGraph_t cu_graph;
cudaGraphCreate(&cu_graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
auto cu_graph_status = graph->populate_cuda_graph(
handle, variant_pack, workspace_ptr, cu_graph);
if (!cu_graph_status.is_good()) {
throw std::runtime_error(
"Unable to add cuda graph for cudnn attention."
" Failed with message: " +
cu_graph_status.get_message());
}
encoder.add_graph_node(cu_graph);
}
encoder.add_temporary(workspace);
}
} // namespace

namespace fast {
@@ -945,6 +651,9 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_arr_mask,
     bool do_causal,
     Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
   if (s.device == Device::cpu) {
     return true;
   }
@@ -960,15 +669,7 @@ bool ScaledDotProductAttention::use_fallback(
   const bool supported_vector_config =
       sdpa_supported_head_dim && query_sequence_length < 4;

-  auto& cu_device = cu::device(s.device);
-  const bool supported_matrix_config = query_sequence_length > 4 &&
-      cu_device.compute_capability_major() >= 8 &&
-      query_sequence_length == key_sequence_length &&
-      (q.dtype() == float16 || q.dtype() == bfloat16);
-
-  const bool supported_config =
-      (supported_matrix_config || supported_vector_config);
+  const bool supported_config = supported_vector_config;

   return has_arr_mask || !supported_config;
 }
@@ -1002,10 +703,6 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };

-  auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(-1) == 1;
-  };
-
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -1059,7 +756,7 @@ void ScaledDotProductAttention::eval_gpu(
     array::Flags flags{
         /* bool contiguous = */ 1,
-        /* bool row_contiguous = */ 0,
+        /* bool row_contiguous = */ o.shape(2) == 1,
         /* bool col_contiguous = */ 0,
     };
@@ -1073,35 +770,9 @@ void ScaledDotProductAttention::eval_gpu(
     return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
   }

-  // Full attention mode
+  // Full attention mode should never reach here
   else {
-    const auto& q = copy_unless(is_matrix_contiguous, q_pre);
-    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
-    const auto& v = copy_unless(is_matrix_contiguous, v_pre);
-
-    for (const auto& cp : copies) {
-      encoder.add_temporary(cp);
-    }
-
-    int64_t str_oD = 1;
-    int64_t str_oH = o.shape(3);
-    int64_t str_oL = o.shape(1) * str_oH;
-    int64_t str_oB = o.shape(2) * str_oL;
-    size_t data_size = o.shape(0) * str_oB;
-
-    array::Flags flags{
-        /* bool contiguous = */ 1,
-        /* bool row_contiguous = */ 0,
-        /* bool col_contiguous = */ 0,
-    };
-
-    o.set_data(
-        allocator::malloc(o.nbytes()),
-        data_size,
-        {str_oB, str_oH, str_oL, str_oD},
-        flags);
-
-    return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
+    throw std::runtime_error("Doesn't support matrix yet.");
   }
 }
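The path this change targets is single-query (decode-style) attention, where the query sequence length is below 4. A minimal sketch of calling the fused op from Python, with illustrative shapes and dtype:

    import mlx.core as mx

    B, H, D, kL = 1, 8, 64, 512           # illustrative sizes; qL < 4 hits the vector kernel
    q = mx.random.normal((B, H, 1, D))    # single query token
    k = mx.random.normal((B, H, kL, D))
    v = mx.random.normal((B, H, kL, D))

    o = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5)
    print(o.shape)  # (1, 8, 1, 64)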

View File

@@ -104,7 +104,7 @@ struct CommandEncoder {
   };

   // Outputs of all kernels in the encoder including temporaries
-  std::unordered_set<const void*> outputs() {
+  std::unordered_set<const void*>& outputs() {
     return all_outputs_;
   };

View File

@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
 instantiate_and_or(or, Or)

 #define instantiate_sum_prod(name, op)                                 \
+  instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op)      \
+  instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op)   \
+  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op)   \
+  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op)   \
   instantiate_reduce_functions(name, int8, int8_t, int32_t, op)        \
   instantiate_reduce_functions(name, int16, int16_t, int32_t, op)      \
   instantiate_reduce_functions(name, int32, int32_t, int32_t, op)      \

View File

@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
     const std::string& op_name) {
   if (op_name == "sum" || op_name == "prod") {
     if (issubdtype(in.dtype(), integer)) {
-      switch (in.dtype().size()) {
-        case 1:
+      switch (in.dtype()) {
+        case uint8:
+          return {uint8, uint32};
+        case uint16:
+          return {uint16, uint32};
+        case uint32:
+          return {uint32, uint32};
+        case uint64:
+          return {uint64, uint64};
+        case int8:
           return {int8, int32};
-        case 2:
+        case int16:
           return {int16, int32};
-        case 4:
+        case int32:
           return {int32, int32};
-        case 8:
+        case int64:
           return {int64, int64};
+        default:
+          throw std::runtime_error("Unsupported integer type");
       }
     }
     if (in.dtype() == bool_) {
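Read as a mapping from input dtype to ``{compute type, output type}``, the switch above amounts to the following; this is a plain restatement for reference, not an MLX API:

    # sum/prod reduce type remapping (from the switch above)
    ACC_TYPES = {
        "uint8":  ("uint8",  "uint32"),
        "uint16": ("uint16", "uint32"),
        "uint32": ("uint32", "uint32"),
        "uint64": ("uint64", "uint64"),
        "int8":   ("int8",   "int32"),
        "int16":  ("int16",  "int32"),
        "int32":  ("int32",  "int32"),
        "int64":  ("int64",  "int64"),
    }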

View File

@@ -2381,9 +2381,20 @@ array logsumexp(
     throw std::invalid_argument(
         "[logsumexp] Received non-empty axes for array with 0 dimensions.");
   }
+  bool reduce_last_dim =
+      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
+  if (reduce_last_dim) {
+    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
+    // is [1, 1, ..., N].
+    for (int i = axes.size() - 2; i >= 0; --i) {
+      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
+        reduce_last_dim = false;
+        break;
+      }
+    }
+  }
   bool is_complex = issubdtype(a.dtype(), complexfloating);
-  if (!is_complex && axes.size() == 1 &&
-      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  if (!is_complex && reduce_last_dim) {
     auto dtype = at_least_float(a.dtype());
     auto out_shape = a.shape();
     out_shape.back() = 1;
@@ -3403,10 +3414,20 @@ array softmax(
     throw std::invalid_argument(
         "[softmax] Received non-empty axes for array with 0 dimensions.");
   }
+  bool reduce_last_dim =
+      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
+  if (reduce_last_dim) {
+    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
+    // is [1, 1, ..., N].
+    for (int i = axes.size() - 2; i >= 0; --i) {
+      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
+        reduce_last_dim = false;
+        break;
+      }
+    }
+  }
   bool is_complex = issubdtype(a.dtype(), complexfloating);
-  if (!is_complex && axes.size() == 1 &&
-      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  if (!is_complex && reduce_last_dim) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
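In effect, the relaxed check lets multi-axis reductions that still end on the last dimension (with size-1 leading axes) take the fused last-dim primitive. A small sketch of the kind of call this now covers; shapes are illustrative:

    import mlx.core as mx

    x = mx.random.normal((1, 1, 4096))
    # Previously only a single trailing axis qualified; reducing axes [0, 1, 2]
    # over a [1, 1, N] input now also counts as a last-dim reduction and can fuse.
    y = mx.softmax(x, axis=[0, 1, 2])
    z = mx.logsumexp(x, axis=[0, 1, 2])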

View File

@@ -3,8 +3,8 @@
 #pragma once

 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 27
-#define MLX_VERSION_PATCH 1
+#define MLX_VERSION_MINOR 28
+#define MLX_VERSION_PATCH 0
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -2,6 +2,6 @@
 requires = [
     "setuptools>=80",
     "nanobind==2.4.0",
-    "cmake>=3.25",
+    "cmake>=3.25,<4.1",
 ]
 build-backend = "setuptools.build_meta"

View File

@@ -178,7 +178,7 @@ class Module(dict):
         if strict:
             new_weights = dict(weights)
-            curr_weights = dict(tree_flatten(self.parameters()))
+            curr_weights = tree_flatten(self.parameters(), destination={})
             if extras := (new_weights.keys() - curr_weights.keys()):
                 num_extra = len(extras)
                 extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
             - ``.npz`` will use :func:`mx.savez`
             - ``.safetensors`` will use :func:`mx.save_safetensors`
         """
-        params_dict = dict(tree_flatten(self.parameters()))
+        params_dict = tree_flatten(self.parameters(), destination={})
         if file.endswith(".npz"):
             mx.savez(file, **params_dict)
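A minimal sketch of the call sites this touches, assuming a small ``mlx.nn`` model (the layer choice is illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)

    # save_weights now flattens the parameters straight into a dict
    model.save_weights("weights.safetensors")

    # load_weights with strict=True compares against the same flattened dict
    model.load_weights("weights.safetensors")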

View File

@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.

 from collections import defaultdict
 from itertools import zip_longest
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union


 def tree_map(
@@ -114,8 +114,11 @@ def tree_map_with_path(
 def tree_flatten(
-    tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
-) -> Any:
+    tree: Any,
+    prefix: str = "",
+    is_leaf: Optional[Callable] = None,
+    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
+) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
     """Flattens a Python tree to a list of key, value tuples.

     The keys are using the dot notation to define trees of arbitrary depth and
@@ -128,9 +131,12 @@ def tree_flatten(
         print(tree_flatten([[[0]]]))
         # [("0.0.0", 0)]

-        print(tree_flatten([[[0]]], ".hello"))
+        print(tree_flatten([[[0]]], prefix=".hello"))
         # [("hello.0.0.0", 0)]

+        tree_flatten({"a": {"b": 1}}, destination={})
+        {"a.b": 1}
+
     .. note::
        Dictionaries should have keys that are valid Python identifiers.
@@ -140,26 +146,50 @@ def tree_flatten(
            always discarded.
        is_leaf (callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.
+       destination (list or dict, optional): A list or dictionary to store the
+           flattened tree. If None an empty list will be used. Default: ``None``.

     Returns:
-        List[Tuple[str, Any]]: The flat representation of the Python tree.
+        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
+        the Python tree.
     """
-    flat_tree = []
+    if destination is None:
+        destination = []

-    if is_leaf is None or not is_leaf(tree):
-        if isinstance(tree, (list, tuple)):
-            for i, t in enumerate(tree):
-                flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
-            return flat_tree
-        if isinstance(tree, dict):
-            for k, t in tree.items():
-                flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
-            return flat_tree
+    # Create the function to update the destination. We are taking advantage of
+    # the fact that list.extend and dict.update have the same API to simplify
+    # the code a bit.
+    if isinstance(destination, list):
+        _add_to_destination = destination.extend
+    elif isinstance(destination, dict):
+        _add_to_destination = destination.update
+    else:
+        raise ValueError("Destination should be either a list or a dictionary or None")

-    return [(prefix[1:], tree)]
+    # Leaf identified by is_leaf so add it and return
+    if is_leaf is not None and is_leaf(tree):
+        _add_to_destination([(prefix[1:], tree)])
+        return destination
+
+    # List or tuple so recursively add each subtree
+    if isinstance(tree, (list, tuple)):
+        for i, item in enumerate(tree):
+            tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
+        return destination
+
+    # Dictionary so recursively add each subtree
+    if isinstance(tree, dict):
+        for key, value in tree.items():
+            tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
+        return destination
+
+    # Leaf so add it and return
+    _add_to_destination([(prefix[1:], tree)])
+    return destination


-def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
+def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
     """Recreate a Python tree from its flat representation.

     .. code-block:: python
@@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
         print(d)
         # {"hello": {"world": 42}}

+        d = tree_unflatten({"hello.world": 42})
+        print(d)
+        # {"hello": {"world": 42}}
+
     Args:
-        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
+        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
             For instance as returned by :meth:`tree_flatten`.

     Returns:
         A Python tree.
     """
-    if len(tree) == 1 and tree[0][0] == "":
-        return tree[0][1]
+    items = tree.items() if isinstance(tree, dict) else tree

-    try:
-        int(tree[0][0].split(".", maxsplit=1)[0])
-        is_list = True
-    except ValueError:
-        is_list = False
+    # Special case when we have just one element in the tree ie not a tree
+    if len(items) == 1:
+        key, value = next(iter(items))
+        if key == "":
+            return value

     # collect children
     children = defaultdict(list)
-    for key, value in tree:
+    for key, value in items:
         current_idx, *next_idx = key.split(".", maxsplit=1)
         next_idx = "" if not next_idx else next_idx[0]
         children[current_idx].append((next_idx, value))

-    # recursively map them to the original container
-    if is_list:
+    # Assume they are a list and fail to dict if the keys are not all integers
+    try:
         keys = sorted((int(idx), idx) for idx in children.keys())
         l = []
         for i, k in keys:
@@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
             l.extend([{} for _ in range(i - len(l))])
             l.append(tree_unflatten(children[k]))
         return l
-    else:
+    except ValueError:
         return {k: tree_unflatten(v) for k, v in children.items()}
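Putting the new signatures together, a short sketch of the list and dict forms side by side:

    from mlx.utils import tree_flatten, tree_unflatten

    tree = {"layers": [{"w": 1}, {"w": 2}]}

    as_list = tree_flatten(tree)                  # [("layers.0.w", 1), ("layers.1.w", 2)]
    as_dict = tree_flatten(tree, destination={})  # {"layers.0.w": 1, "layers.1.w": 2}

    # tree_unflatten now accepts either representation
    assert tree_unflatten(as_list) == tree
    assert tree_unflatten(as_dict) == tree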

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
                 self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}

         model = DictModule()
-        params = dict(tree_flatten(model.parameters()))
+        params = tree_flatten(model.parameters(), destination={})
         self.assertEqual(len(params), 2)
         self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
         self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))

View File

@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
     CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
   }

+  // sum and prod overflow
+  {
+    auto a = full({256, 2, 2}, 1u, uint8);
+    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
+    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
+
+    a = full({65535, 2, 2}, 1u, uint16);
+    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
+    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
+  }
+}
+
+TEST_CASE("test gpu reduce with axes") {
   // reducing only some axes and irregular layouts
   {
     array a(1.0f);

View File

@@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
     CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
   }

+  // Test unsigned sum
+  {
+    const int num_elems = 1000;
+    auto x = astype(full({num_elems}, 255), uint8);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
+
+    x = astype(full({num_elems}, 65535), uint16);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
+
+    x = full({3, 3, 3}, 10000, uint32);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
+
+    x = full({3, 3, 3}, 10000, uint64);
+    CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
+  }
+
   // Test prod
   {
     auto x = array({});
@@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
     CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
   }

+  // Test unsigned prod
+  {
+    auto x = array({255, 255}, {2}, uint8);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
+
+    x = array({65535, 2}, {2}, uint16);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
+
+    x = array({100000, 2}, {2}, uint32);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
+
+    x = array({100000, 2}, {2}, uint64);
+    CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
+  }
+
   // Test all
   {
     auto x = array({});