Compare commits

...

4 Commits

Author SHA1 Message Date
Awni Hannun
56be773610 version (#2470) 2025-08-07 00:36:04 -07:00
Jagrit Digani
a9bdd67baa Add CUDA sdpa vector (#2468) 2025-08-06 21:40:26 -07:00
Angelos Katharopoulos
f2adb5638d Fix typo in metal command encoder (#2471) 2025-08-06 16:58:23 -07:00
Luca Vivona
728d4db582 Support destination arg in tree flatten/unflatten (#2450) 2025-08-06 15:34:59 -07:00
10 changed files with 866 additions and 63 deletions

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = dict(tree_flatten(model.parameters()))
params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -39,6 +39,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu

View File

@@ -6,17 +6,6 @@
namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

View File

@@ -0,0 +1,781 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
#define PRAGMA_LOOP_UNROLL #pragma unroll
struct AttnParams {
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
};
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_1pass(
const T* Q,
const T* K,
const T* V,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = BN * int(params.K_strides[2]);
const int inner_v_stride = BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
// Read the query and 0 the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -INFINITY;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = max_scores[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
float* partials,
float* sums,
float* maxs,
__grid_constant__ const AttnParams params) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z / blocks;
const int block_idx = blockIdx.z % blocks;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = block_idx * BN + warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s + // Sequence
block_idx; // Block
partials += p_offset * D;
sums += p_offset;
maxs += p_offset;
// Read the query and 0 the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -1e9;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
// Write the sum and new max
if (warp_idx == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
auto ff = exp2f(max_scores[warp_idx] - new_max);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[warp_idx][lane_idx] = o[i] * ff;
block.sync();
if (warp_idx == 0) {
U ot = outputs[0][lane_idx];
PRAGMA_LOOP_UNROLL
for (int j = 1; j < BN; j++) {
ot += outputs[j][lane_idx];
warp.sync();
}
o[i] = ot;
}
block.sync();
}
if (warp_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
partials[v_per_thread * lane_idx + i] = o[i];
}
}
}
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
const float* partials,
const float* sums,
const float* maxs,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
typedef float U;
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int q_seq_idx = blockIdx.y;
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s; // Sequence
partials += p_offset * D + warp_idx * D;
sums += p_offset;
maxs += p_offset;
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
U max_score = maxs[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = partials[v_per_thread * lane_idx + i];
}
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
} // namespace cu
namespace {
template <typename F>
void dispatch_headdim(int n, F&& f) {
switch (n) {
case 64:
f(std::integral_constant<int, 64>{});
break;
case 96:
f(std::integral_constant<int, 96>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
void sdpa_vector_1pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel =
cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
params);
});
});
});
}
void sdpa_vector_2pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
// Allocate the intermediates
int blocks = 32;
Shape intermediate_shape;
intermediate_shape.reserve(o.ndim() + 1);
intermediate_shape.insert(
intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
intermediate_shape.push_back(blocks);
intermediate_shape.push_back(o.shape().back());
array intermediate(intermediate_shape, float32, nullptr, {});
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
{
auto kernel = cu::
kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
dim3 grid_dim(params.H, params.qL, params.B * 32);
dim3 block_dim(8 * 32, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
params);
}
{
auto kernel = cu::
kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(intermediate);
encoder.set_input_array(sums);
encoder.set_input_array(maxs);
encoder.set_output_array(o);
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
params);
}
});
});
});
}
void sdpa_vector_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
}
}
} // namespace
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
return has_arr_mask || !supported_config;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {
return arr;
}
};
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
int64_t str_oD = 1;
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here
else {
throw std::runtime_error("Doesn't support matrix yet.");
}
}
} // namespace fast
} // namespace mlx::core

View File

@@ -104,7 +104,7 @@ struct CommandEncoder {
};
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*> outputs() {
std::unordered_set<const void*>& outputs() {
return all_outputs_;
};

View File

@@ -3,8 +3,8 @@
#pragma once
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 27
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_MINOR 28
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict:
new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters()))
curr_weights = tree_flatten(self.parameters(), destination={})
if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras)
extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors`
"""
params_dict = dict(tree_flatten(self.parameters()))
params_dict = tree_flatten(self.parameters(), destination={})
if file.endswith(".npz"):
mx.savez(file, **params_dict)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
def tree_map(
@@ -114,8 +114,11 @@ def tree_map_with_path(
def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
tree: Any,
prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -128,9 +131,12 @@ def tree_flatten(
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], ".hello"))
print(tree_flatten([[[0]]], prefix=".hello"))
# [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note::
Dictionaries should have keys that are valid Python identifiers.
@@ -140,26 +146,50 @@ def tree_flatten(
always discarded.
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns:
List[Tuple[str, Any]]: The flat representation of the Python tree.
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
the Python tree.
"""
flat_tree = []
if destination is None:
destination = []
if is_leaf is None or not is_leaf(tree):
if isinstance(tree, (list, tuple)):
for i, t in enumerate(tree):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return flat_tree
if isinstance(tree, dict):
for k, t in tree.items():
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return flat_tree
# Create the function to update the destination. We are taking advantage of
# the fact that list.extend and dict.update have the same API to simplify
# the code a bit.
if isinstance(destination, list):
_add_to_destination = destination.extend
elif isinstance(destination, dict):
_add_to_destination = destination.update
else:
raise ValueError("Destination should be either a list or a dictionary or None")
return [(prefix[1:], tree)]
# Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
print(d)
# {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args:
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A Python tree.
"""
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
items = tree.items() if isinstance(tree, dict) else tree
try:
int(tree[0][0].split(".", maxsplit=1)[0])
is_list = True
except ValueError:
is_list = False
# Special case when we have just one element in the tree ie not a tree
if len(items) == 1:
key, value = next(iter(items))
if key == "":
return value
# collect children
children = defaultdict(list)
for key, value in tree:
for key, value in items:
current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value))
# recursively map them to the original container
if is_list:
# Assume they are a list and fail to dict if the keys are not all integers
try:
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
@@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k]))
return l
else:
except ValueError:
return {k: tree_unflatten(v) for k, v in children.items()}

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule()
params = dict(tree_flatten(model.parameters()))
params = tree_flatten(model.parameters(), destination={})
self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))