Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: branch sdpav-back, 6 commits, head 7fde1b6a1e

- 7fde1b6a1e
- aa7b47481a
- 56be773610
- a9bdd67baa
- f2adb5638d
- 728d4db582
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
     optimizer.update(model, grads)

     # Save the state
-    state = tree_flatten(optimizer.state)
-    mx.save_safetensors("optimizer.safetensors", dict(state))
+    state = tree_flatten(optimizer.state, destination={})
+    mx.save_safetensors("optimizer.safetensors", state)

     # Later on, for example when loading from a checkpoint,
     # recreate the optimizer and load the state
     optimizer = optim.Adam(learning_rate=1e-2)

-    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+    state = tree_unflatten(mx.load("optimizer.safetensors"))
     optimizer.state = state

 Note, not every optimizer configuration parameter is saved in the state. For
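The change above drops the `dict(...)` wrapper by passing the new `destination` argument. A quick sketch of the equivalence (assuming an MLX install with this patch):

```python
from mlx.utils import tree_flatten

state = {"step": 3, "learning_rate": 0.01}

as_list = tree_flatten(state)                  # [("step", 3), ("learning_rate", 0.01)]
as_dict = tree_flatten(state, destination={})  # {"step": 3, "learning_rate": 0.01}

# The dict form is what mx.save_safetensors expects, with no extra copy.
assert dict(as_list) == as_dict
```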
@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
         model.update(tree_unflatten(list(params.items())))
         return model(x)

-    params = dict(tree_flatten(model.parameters()))
+    params = tree_flatten(model.parameters(), destination={})
     mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

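For completeness, the exported file is loaded back with `mx.import_function`; the exact call pattern below (re-passing `params` as keyword inputs, and the `model`/`call` names from the docs example) is an assumption, not part of this diff:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

# Assumes `model` and `call` are defined as in the docs example above.
params = tree_flatten(model.parameters(), destination={})

imported = mx.import_function("model.mlxfn")
outputs = imported(mx.zeros(4), **params)
```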
@@ -39,6 +39,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
@@ -10,7 +10,34 @@ namespace mlx::core::cu {

 namespace cg = cooperative_groups;

-__global__ void set_mm_device_pointers(
+template <int NDIM>
+__global__ void set_mm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_mm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
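The new `_nd` variants trade the generic stride walk for a compile-time `NDIM` loop. The offset arithmetic itself is a mixed-radix decomposition of the linear batch index against each operand's batch strides; a Python sketch for intuition (names hypothetical):

```python
def elem_to_loc(index, batch_shape, a_strides, b_strides):
    # Decompose `index` over batch_shape (innermost dim first) and
    # accumulate per-operand element offsets through their strides.
    a_offset = b_offset = 0
    for dim, a_s, b_s in zip(reversed(batch_shape),
                             reversed(a_strides),
                             reversed(b_strides)):
        index, pos = divmod(index, dim)
        a_offset += pos * a_s
        b_offset += pos * b_s
    return a_offset, b_offset

# Batch shape (2, 3) with A broadcast along the first batch dim:
assert elem_to_loc(4, (2, 3), (0, 16), (48, 16)) == (16, 64)
```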
@@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers(
       out_start + item_size * index * batch_stride;
 }

-__global__ void set_addmm_device_pointers(
+template <int NDIM>
+__global__ void set_addmm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* c_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      c_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
+  pointers[index + 3 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_addmm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
@@ -89,37 +147,62 @@ void Matmul::run_batched(
     const mlx::core::Shape& batch_shape,
     const mlx::core::Strides& a_batch_strides,
     const mlx::core::Strides& b_batch_strides) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(out_desc_, batch_count);

   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
-      {static_cast<int>(batch_count * 3)},
+      allocator::malloc(batch_count * sizeof(void*) * 3),
+      {batch_count * 3},
       uint64);

   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);

+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
       encoder.add_kernel_node(
-          cu::set_mm_device_pointers,
-          cuda::ceil_div(pointers.size(), block_size),
-          block_size,
+          cu::set_mm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
           0,
           pointers.data<int8_t*>(),
           a.data<int8_t>(),
           b.data<int8_t>(),
           out.data<int8_t>(),
-          static_cast<int>(out.dtype().size()),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_mm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
         const_param(batch_shape),
         const_param(a_batch_strides),
         const_param(b_batch_strides),
-        static_cast<int64_t>(M_) * N_,
-        static_cast<int>(batch_shape.size()),
+        batch_stride,
+        ndim,
         batch_count);
+  }

   // Run matmul
   encoder.set_input_array(pointers);
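The launch sizing above replaces a fixed 512-thread block with one thread per batch element, capped at 256 per block; the pointer buffer holds `3 * batch_count` entries (A pointers, then B, then outputs). A sketch of the arithmetic:

```python
def launch_dims(batch_count, max_threads=256):
    block_dims = min(batch_count, max_threads)
    num_blocks = -(-batch_count // block_dims)  # ceil_div
    return num_blocks, block_dims

assert launch_dims(1000) == (4, 256)  # 4 blocks cover 1000 batch elements
assert launch_dims(8) == (1, 8)       # small batches use a single small block
```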
@@ -150,7 +233,7 @@ void Matmul::run_batched(
     const mlx::core::Strides& c_batch_strides,
     float alpha,
     float beta) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(c_desc_, batch_count);
@@ -159,30 +242,58 @@ void Matmul::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
       allocator::malloc(batch_count * sizeof(uint64_t) * 4),
-      {static_cast<int>(batch_count * 4)},
+      {batch_count * 4},
       uint64);

   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);

+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
       encoder.add_kernel_node(
-          cu::set_addmm_device_pointers,
-          cuda::ceil_div(pointers.size(), block_size),
-          block_size,
+          cu::set_addmm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
           0,
           pointers.data<int8_t*>(),
           a.data<int8_t>(),
           b.data<int8_t>(),
           c.data<int8_t>(),
           out.data<int8_t>(),
-          static_cast<int>(out.dtype().size()),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          const_param<ndim_constant()>(c_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_addmm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        c.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
         const_param(batch_shape),
         const_param(a_batch_strides),
         const_param(b_batch_strides),
         const_param(c_batch_strides),
-        static_cast<int64_t>(M_) * N_,
-        static_cast<int>(batch_shape.size()),
+        batch_stride,
+        ndim,
         batch_count);
+  }

   // Run matmul
   encoder.set_input_array(pointers);
@@ -6,17 +6,6 @@

 namespace mlx::core {

-bool fast::ScaledDotProductAttention::use_fallback(
-    const array& q,
-    const array& k,
-    const array& v,
-    bool has_mask,
-    bool has_arr_mask,
-    bool do_causal,
-    Stream s) {
-  return true;
-}
-
 #define NO_GPU_MULTI(func)                                             \
   void func::eval_gpu(                                                 \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)

 namespace fast {
-NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast

mlx/backend/cuda/scaled_dot_product_attention.cu (new file, +781 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"

#include <nvtx3/nvtx3.hpp>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

#define PRAGMA_LOOP_UNROLL _Pragma("unroll")

struct AttnParams {
  int B;
  int H;
  int D;

  int qL;
  int kL;

  int gqa_factor;
  float scale;

  int64_t Q_strides[3];
  int64_t K_strides[3];
  int64_t V_strides[3];
  int64_t O_strides[3];
};

template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_1pass(
    const T* Q,
    const T* K,
    const T* V,
    T* O,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 32;
  constexpr int BD = 32;

  constexpr int v_per_thread = D / BD;

  const int inner_k_stride = BN * int(params.K_strides[2]);
  const int inner_v_stride = BN * int(params.V_strides[2]);

  typedef float U;

  U q[v_per_thread];
  U k[v_per_thread];
  U o[v_per_thread];

  __shared__ U outputs[BN][BD + 1];
  __shared__ U max_scores[BN];
  __shared__ U sum_exp_scores[BN];

  const U scale_log2 = params.scale * 1.44269504089f;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z;
  const int head_idx = blockIdx.x;
  const int kv_head_idx = head_idx / params.gqa_factor;

  const int q_seq_idx = blockIdx.y;
  const int kv_seq_idx = warp_idx;

  Q += batch_idx * params.Q_strides[0] + // Batch
      head_idx * params.Q_strides[1] + // Head
      q_seq_idx * params.Q_strides[2]; // Sequence

  K += batch_idx * params.K_strides[0] + // Batch
      kv_head_idx * params.K_strides[1] + // Head
      kv_seq_idx * params.K_strides[2]; // Sequence

  V += batch_idx * params.V_strides[0] + // Batch
      kv_head_idx * params.V_strides[1] + // Head
      kv_seq_idx * params.V_strides[2]; // Sequence

  O += batch_idx * params.O_strides[0] + // Batch
      head_idx * params.O_strides[1] + // Head
      q_seq_idx * params.O_strides[2]; // Sequence

  // Read the query and 0 the output accumulator
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
  }

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0.f;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0.f;

  // For each key
  for (int i = kv_seq_idx; i < params.kL; i += BN) {
    bool use_key = true;
    if constexpr (do_causal) {
      use_key = i <= (params.kL - params.qL + q_seq_idx);
    }

    if (use_key) {
      // Read the key
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        k[j] = K[v_per_thread * lane_idx + j];
      }

      // Compute the i-th score
      U score = 0.f;
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        score += q[j] * k[j];
      }

      // Warp sum
      score = cg::reduce(warp, score, cg::plus<U>());

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = exp2f(max_score - new_max);
      U exp_score = exp2f(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor +
            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
      }
    }

    // Move the pointers to the next kv
    K += inner_k_stride;
    V += inner_v_stride;
  }

  if (lane_idx == 0) {
    max_scores[warp_idx] = max_score;
    sum_exp_scores[warp_idx] = sum_exp_score;
  }
  block.sync();

  max_score = max_scores[lane_idx];
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  sum_exp_score =
      cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
  sum_exp_score = __frcp_rn(sum_exp_score);

  // Now we need to aggregate all the outputs
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[lane_idx][warp_idx] = o[i];
    block.sync();
    U ot = outputs[warp_idx][lane_idx] * factor;
    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
    block.sync();
  }

  // And write the output
  if (lane_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
    }
  }
}
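The key loop above is a streaming ("online") softmax: each new score updates a running maximum, and the running normalizer and output are rescaled so no exponential overflows. A scalar Python sketch of the same recurrence, using base-2 exponentials as the kernel does:

```python
import math

def online_softmax_av(scores, values):
    max_score, sum_exp, out = -math.inf, 0.0, 0.0
    for s, v in zip(scores, values):
        new_max = max(max_score, s)
        factor = 2.0 ** (max_score - new_max)  # rescales the old state
        exp_s = 2.0 ** (s - new_max)
        max_score = new_max
        sum_exp = sum_exp * factor + exp_s
        out = out * factor + exp_s * v
    return out / sum_exp

# Agrees with the direct softmax-weighted sum:
scores, values = [0.5, 2.0, -1.0], [1.0, 2.0, 3.0]
direct = sum(2.0**s * v for s, v in zip(scores, values)) / sum(2.0**s for s in scores)
assert abs(online_softmax_av(scores, values) - direct) < 1e-12
```

In the kernel, `q` is pre-multiplied by `scale * log2(e)` so `exp2f` computes `exp(scale * q·k)`, and per-lane vectors play the role of the scalars here.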
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_1(
    const T* Q,
    const T* K,
    const T* V,
    float* partials,
    float* sums,
    float* maxs,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int blocks = 32;

  constexpr int v_per_thread = D / BD;

  const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
  const int inner_v_stride = blocks * BN * int(params.V_strides[2]);

  typedef float U;

  U q[v_per_thread];
  U k[v_per_thread];
  U o[v_per_thread];

  __shared__ U outputs[BN][BD + 1];
  __shared__ U max_scores[BN];
  __shared__ U sum_exp_scores[BN];

  const U scale_log2 = params.scale * 1.44269504089f;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z / blocks;
  const int block_idx = blockIdx.z % blocks;
  const int head_idx = blockIdx.x;
  const int kv_head_idx = head_idx / params.gqa_factor;

  const int q_seq_idx = blockIdx.y;
  const int kv_seq_idx = block_idx * BN + warp_idx;

  Q += batch_idx * params.Q_strides[0] + // Batch
      head_idx * params.Q_strides[1] + // Head
      q_seq_idx * params.Q_strides[2]; // Sequence

  K += batch_idx * params.K_strides[0] + // Batch
      kv_head_idx * params.K_strides[1] + // Head
      kv_seq_idx * params.K_strides[2]; // Sequence

  V += batch_idx * params.V_strides[0] + // Batch
      kv_head_idx * params.V_strides[1] + // Head
      kv_seq_idx * params.V_strides[2]; // Sequence

  const int p_stride_s = blocks;
  const int p_stride_h = params.qL * p_stride_s;
  const int p_stride_b = params.H * p_stride_h;
  const int p_offset = batch_idx * p_stride_b + // Batch
      head_idx * p_stride_h + // Head
      q_seq_idx * p_stride_s + // Sequence
      block_idx; // Block

  partials += p_offset * D;
  sums += p_offset;
  maxs += p_offset;

  // Read the query and 0 the output accumulator
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
  }

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0.f;
  }

  U max_score = -1e9;
  U sum_exp_score = 0.f;

  // For each key
  for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
    bool use_key = true;
    if constexpr (do_causal) {
      use_key = i <= (params.kL - params.qL + q_seq_idx);
    }

    if (use_key) {
      // Read the key
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        k[j] = K[v_per_thread * lane_idx + j];
      }

      // Compute the i-th score
      U score = 0.f;
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        score += q[j] * k[j];
      }

      // Warp sum
      score = cg::reduce(warp, score, cg::plus<U>());

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = exp2f(max_score - new_max);
      U exp_score = exp2f(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor +
            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
      }
    }

    // Move the pointers to the next kv
    K += inner_k_stride;
    V += inner_v_stride;
  }

  if (lane_idx == 0) {
    max_scores[warp_idx] = max_score;
    sum_exp_scores[warp_idx] = sum_exp_score;
  }

  block.sync();

  max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
  sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());

  // Write the sum and new max
  if (warp_idx == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now we need to aggregate all the outputs
  auto ff = exp2f(max_scores[warp_idx] - new_max);
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[warp_idx][lane_idx] = o[i] * ff;
    block.sync();

    if (warp_idx == 0) {
      U ot = outputs[0][lane_idx];
      PRAGMA_LOOP_UNROLL
      for (int j = 1; j < BN; j++) {
        ot += outputs[j][lane_idx];
        warp.sync();
      }
      o[i] = ot;
    }
    block.sync();
  }

  if (warp_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      partials[v_per_thread * lane_idx + i] = o[i];
    }
  }
}
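Pass 1 leaves, for each block of keys, an unnormalized partial output together with its local max and local sum of exponentials; pass 2 rescales everything to the global max before summing. A Python sketch of the merge identity:

```python
def merge_partials(partials, sums, maxs):
    # Rescale each block's partial output and normalizer to the global max.
    new_max = max(maxs)
    total = sum(s * 2.0 ** (m - new_max) for s, m in zip(sums, maxs))
    out = sum(p * 2.0 ** (m - new_max) for p, m in zip(partials, maxs))
    return out / total

def block_stats(scores, values):
    # What pass 1 computes for one block: partial, sum_exp, max (base 2).
    m = max(scores)
    partial = sum(2.0 ** (s - m) * v for s, v in zip(scores, values))
    sum_exp = sum(2.0 ** (s - m) for s in scores)
    return partial, sum_exp, m

scores, values = [0.5, 2.0, -1.0, 0.25], [1.0, 2.0, 3.0, 4.0]
p1, s1, m1 = block_stats(scores[:2], values[:2])
p2, s2, m2 = block_stats(scores[2:], values[2:])
direct = sum(2.0**s * v for s, v in zip(scores, values)) / sum(2.0**s for s in scores)
assert abs(merge_partials([p1, p2], [s1, s2], [m1, m2]) - direct) < 1e-12
```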
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
    const float* partials,
    const float* sums,
    const float* maxs,
    T* O,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int blocks = 32;

  constexpr int v_per_thread = D / BD;

  typedef float U;

  U o[v_per_thread];
  __shared__ U outputs[BN][BD + 1];

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z;
  const int head_idx = blockIdx.x;
  const int q_seq_idx = blockIdx.y;

  const int p_stride_s = blocks;
  const int p_stride_h = params.qL * p_stride_s;
  const int p_stride_b = params.H * p_stride_h;
  const int p_offset = batch_idx * p_stride_b + // Batch
      head_idx * p_stride_h + // Head
      q_seq_idx * p_stride_s; // Sequence

  partials += p_offset * D + warp_idx * D;
  sums += p_offset;
  maxs += p_offset;

  O += batch_idx * params.O_strides[0] + // Batch
      head_idx * params.O_strides[1] + // Head
      q_seq_idx * params.O_strides[2]; // Sequence

  U max_score = maxs[lane_idx];
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
  sum_exp_score = __frcp_rn(sum_exp_score);

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = partials[v_per_thread * lane_idx + i];
  }

  // Now we need to aggregate all the outputs
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[lane_idx][warp_idx] = o[i];
    block.sync();
    U ot = outputs[warp_idx][lane_idx] * factor;
    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
    block.sync();
  }

  // And write the output
  if (lane_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
    }
  }
}
} // namespace cu

namespace {

template <typename F>
void dispatch_headdim(int n, F&& f) {
  switch (n) {
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 96:
      f(std::integral_constant<int, 96>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

void sdpa_vector_1pass_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  encoder.set_input_array(q);
  encoder.set_input_array(k);
  encoder.set_input_array(v);
  encoder.set_output_array(o);

  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
      /* int D = */ q.shape(3),

      /* int qL = */ q.shape(2),
      /* int kL = */ k.shape(2),

      /* int gqa_factor = */ q.shape(1) / k.shape(1),
      /* float scale = */ scale,

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  dim3 grid_dim(params.H, params.qL, params.B);
  dim3 block_dim(1024, 1, 1);

  dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

        auto kernel =
            cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
        encoder.add_kernel_node(
            kernel,
            grid_dim,
            block_dim,
            0,
            q.data<DataType>(),
            k.data<DataType>(),
            v.data<DataType>(),
            o.data<DataType>(),
            params);
      });
    });
  });
}
void sdpa_vector_2pass_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
      /* int D = */ q.shape(3),

      /* int qL = */ q.shape(2),
      /* int kL = */ k.shape(2),

      /* int gqa_factor = */ q.shape(1) / k.shape(1),
      /* float scale = */ scale,

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  // Allocate the intermediates
  int blocks = 32;

  Shape intermediate_shape;
  intermediate_shape.reserve(o.ndim() + 1);
  intermediate_shape.insert(
      intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
  intermediate_shape.push_back(blocks);
  intermediate_shape.push_back(o.shape().back());

  array intermediate(intermediate_shape, float32, nullptr, {});
  intermediate_shape.pop_back();
  array sums(intermediate_shape, float32, nullptr, {});
  array maxs(std::move(intermediate_shape), float32, nullptr, {});

  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
  sums.set_data(allocator::malloc(sums.nbytes()));
  maxs.set_data(allocator::malloc(maxs.nbytes()));

  encoder.add_temporary(intermediate);
  encoder.add_temporary(sums);
  encoder.add_temporary(maxs);

  dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

        {
          auto kernel = cu::
              kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;

          encoder.set_input_array(q);
          encoder.set_input_array(k);
          encoder.set_input_array(v);
          encoder.set_output_array(intermediate);
          encoder.set_output_array(sums);
          encoder.set_output_array(maxs);

          dim3 grid_dim(params.H, params.qL, params.B * 32);
          dim3 block_dim(8 * 32, 1, 1);

          encoder.add_kernel_node(
              kernel,
              grid_dim,
              block_dim,
              0,
              q.data<DataType>(),
              k.data<DataType>(),
              v.data<DataType>(),
              intermediate.data<float>(),
              sums.data<float>(),
              maxs.data<float>(),
              params);
        }

        {
          auto kernel = cu::
              kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;

          encoder.set_input_array(intermediate);
          encoder.set_input_array(sums);
          encoder.set_input_array(maxs);
          encoder.set_output_array(o);

          dim3 grid_dim(params.H, params.qL, params.B);
          dim3 block_dim(1024, 1, 1);

          encoder.add_kernel_node(
              kernel,
              grid_dim,
              block_dim,
              0,
              intermediate.data<float>(),
              sums.data<float>(),
              maxs.data<float>(),
              o.data<DataType>(),
              params);
        }
      });
    });
  });
}

void sdpa_vector_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  int kL = k.shape(2);

  if (kL > 1024) {
    return sdpa_vector_2pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
  } else {
    return sdpa_vector_1pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
  }
}
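Rough intuition for the `kL > 1024` switch, read off the launch shapes (this reasoning is inferred, not stated in the commit): the 1-pass kernel gives each (head, query, batch) a single block of 32 warps, so each warp walks about `kL / 32` keys; the 2-pass kernel spreads the keys over 32 blocks of 8 warps to expose more parallelism on long contexts.

```python
kL = 4096
keys_per_warp_1pass = kL / 32         # 128 keys per warp in one block
keys_per_warp_2pass = kL / (32 * 8)   # 16 keys per warp across 32 blocks
```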
} // namespace

namespace fast {

bool ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  if (detail::in_grad_tracing()) {
    return true;
  }
  if (s.device == Device::cpu) {
    return true;
  }

  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);

  const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);

  const bool supported_vector_config =
      sdpa_supported_head_dim && query_sequence_length < 4;

  const bool supported_config = supported_vector_config;

  return has_arr_mask || !supported_config;
}
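Read positively, `use_fallback` admits the CUDA path only for short-query ("vector" / decode) attention. A Python restatement of the gate:

```python
def uses_cuda_sdpa(q_shape, v_shape, has_arr_mask, grad_tracing, on_cpu):
    # Mirrors use_fallback() above, inverted: no gradient tracing, on GPU,
    # no array mask, a supported head dim, and a short query.
    if grad_tracing or on_cpu or has_arr_mask:
        return False
    _, _, qL, D = q_shape
    head_dim_ok = D == v_shape[-1] and D in (64, 96, 128)
    return head_dim_ok and qL < 4

assert uses_cuda_sdpa((1, 32, 1, 128), (1, 8, 4096, 128), False, False, False)
assert not uses_cuda_sdpa((1, 32, 512, 128), (1, 8, 512, 128), False, False, False)
```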
void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    array& out) {
  nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  auto& q_pre = inputs[0];
  auto& k_pre = inputs[1];
  auto& v_pre = inputs[2];
  auto& o = out;

  std::vector<array> copies;

  // Define some copy functions to ensure the layout of the inputs is as
  // expected.
  copies.reserve(3);
  auto copy_unless = [&copies, &s](
                         auto predicate, const array& arr) -> const array& {
    if (!predicate(arr)) {
      array arr_copy = contiguous_copy_gpu(arr, s);
      copies.push_back(std::move(arr_copy));
      return copies.back();
    } else {
      return arr;
    }
  };

  // We are in vector mode ie single query
  if (q_pre.shape(2) < 4) {
    auto q_copy_unless = [](const array& arr) {
      if (arr.flags().row_contiguous) {
        return true;
      }
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (shape[0] == 1 || shape[1] == 1) {
        // If either the batch or head dimension is a singleton, the other can
        // be transposed with the sequence dimension
        auto bidx = shape[0] == 1 ? 1 : 0;
        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
            (strides[bidx] == shape[3]);
      }
      return false;
    };

    auto kv_copy_unless = [](const array& arr) {
      // keys and values should be copied if:
      // - the last dimension is not contiguous
      // - the batch and head dim are not contiguous
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (strides.back() != 1) {
        return false;
      }
      if (shape[0] == 1 || shape[1] == 1) {
        return true;
      }
      return (strides[0] == strides[1] * shape[1]);
    };

    const auto& q = copy_unless(q_copy_unless, q_pre);
    const auto& k = copy_unless(kv_copy_unless, k_pre);
    const auto& v = copy_unless(kv_copy_unless, v_pre);

    for (const auto& cp : copies) {
      encoder.add_temporary(cp);
    }

    // Donate the query if possible
    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
      o.copy_shared_buffer(q);
    } else {
      int64_t str_oD = 1;
      int64_t str_oH = o.shape(3);
      int64_t str_oL = o.shape(1) * str_oH;
      int64_t str_oB = o.shape(2) * str_oL;
      size_t data_size = o.shape(0) * str_oB;

      array::Flags flags{
          /* bool contiguous = */ 1,
          /* bool row_contiguous = */ o.shape(2) == 1,
          /* bool col_contiguous = */ 0,
      };

      o.set_data(
          allocator::malloc(o.nbytes()),
          data_size,
          {str_oB, str_oH, str_oL, str_oD},
          flags);
    }

    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
  }

  // Full attention mode should never reach here
  else {
    throw std::runtime_error("Doesn't support matrix yet.");
  }
}

} // namespace fast

} // namespace mlx::core
@@ -104,7 +104,7 @@ struct CommandEncoder {
   };

   // Outputs of all kernels in the encoder including temporaries
-  std::unordered_set<const void*> outputs() {
+  std::unordered_set<const void*>& outputs() {
     return all_outputs_;
   };

mlx/ops.cpp (31 lines changed)
@@ -2381,9 +2381,20 @@ array logsumexp(
     throw std::invalid_argument(
         "[logsumexp] Received non-empty axes for array with 0 dimensions.");
   }
+  bool reduce_last_dim =
+      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
+  if (reduce_last_dim) {
+    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
+    // is [1, 1, ..., N].
+    for (int i = axes.size() - 2; i >= 0; --i) {
+      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
+        reduce_last_dim = false;
+        break;
+      }
+    }
+  }
   bool is_complex = issubdtype(a.dtype(), complexfloating);
-  if (!is_complex && axes.size() == 1 &&
-      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  if (!is_complex && reduce_last_dim) {
     auto dtype = at_least_float(a.dtype());
     auto out_shape = a.shape();
     out_shape.back() = 1;
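The new predicate widens the fast path from "a single trailing axis" to any axis list that is effectively the last dimension, because every other reduced axis has size 1. A Python sketch of the check:

```python
def reduces_last_dim(shape, axes):
    ndim = len(shape)
    axes = [a % ndim for a in axes]  # normalize negative axes
    if not axes or axes[-1] != ndim - 1:
        return False
    # The remaining reduced axes must be consecutive and singleton.
    return all(axes[i] + 1 == axes[i + 1] and shape[axes[i]] == 1
               for i in range(len(axes) - 1))

assert reduces_last_dim((1, 1, 128), [0, 1, 2])   # newly on the fast path
assert reduces_last_dim((4, 8, 128), [2])         # the old single-axis case
assert not reduces_last_dim((2, 3, 128), [0, 2])  # non-singleton extra axis
```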
@@ -3403,10 +3414,20 @@ array softmax(
     throw std::invalid_argument(
         "[softmax] Received non-empty axes for array with 0 dimensions.");
   }
+  bool reduce_last_dim =
+      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
+  if (reduce_last_dim) {
+    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
+    // is [1, 1, ..., N].
+    for (int i = axes.size() - 2; i >= 0; --i) {
+      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
+        reduce_last_dim = false;
+        break;
+      }
+    }
+  }
   bool is_complex = issubdtype(a.dtype(), complexfloating);
-  if (!is_complex && axes.size() == 1 &&
-      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  if (!is_complex && reduce_last_dim) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
@@ -3,8 +3,8 @@
 #pragma once

 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 27
-#define MLX_VERSION_PATCH 1
+#define MLX_VERSION_MINOR 28
+#define MLX_VERSION_PATCH 0
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

@@ -178,7 +178,7 @@ class Module(dict):

         if strict:
             new_weights = dict(weights)
-            curr_weights = dict(tree_flatten(self.parameters()))
+            curr_weights = tree_flatten(self.parameters(), destination={})
             if extras := (new_weights.keys() - curr_weights.keys()):
                 num_extra = len(extras)
                 extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
         - ``.npz`` will use :func:`mx.savez`
         - ``.safetensors`` will use :func:`mx.save_safetensors`
         """
-        params_dict = dict(tree_flatten(self.parameters()))
+        params_dict = tree_flatten(self.parameters(), destination={})

         if file.endswith(".npz"):
             mx.savez(file, **params_dict)
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 from collections import defaultdict
 from itertools import zip_longest
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union


 def tree_map(
@@ -114,8 +114,11 @@ def tree_map_with_path(


 def tree_flatten(
-    tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
-) -> Any:
+    tree: Any,
+    prefix: str = "",
+    is_leaf: Optional[Callable] = None,
+    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
+) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
     """Flattens a Python tree to a list of key, value tuples.

     The keys are using the dot notation to define trees of arbitrary depth and
|
|||||||
print(tree_flatten([[[0]]]))
|
print(tree_flatten([[[0]]]))
|
||||||
# [("0.0.0", 0)]
|
# [("0.0.0", 0)]
|
||||||
|
|
||||||
print(tree_flatten([[[0]]], ".hello"))
|
print(tree_flatten([[[0]]], prefix=".hello"))
|
||||||
# [("hello.0.0.0", 0)]
|
# [("hello.0.0.0", 0)]
|
||||||
|
|
||||||
|
tree_flatten({"a": {"b": 1}}, destination={})
|
||||||
|
{"a.b": 1}
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Dictionaries should have keys that are valid Python identifiers.
|
Dictionaries should have keys that are valid Python identifiers.
|
||||||
|
|
||||||
@@ -140,26 +146,50 @@ def tree_flatten(
            always discarded.
        is_leaf (callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.
+       destination (list or dict, optional): A list or dictionary to store the
+           flattened tree. If None an empty list will be used. Default: ``None``.

     Returns:
-        List[Tuple[str, Any]]: The flat representation of the Python tree.
+        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
+        the Python tree.
     """
-    flat_tree = []
+    if destination is None:
+        destination = []

-    if is_leaf is None or not is_leaf(tree):
+    # Create the function to update the destination. We are taking advantage of
+    # the fact that list.extend and dict.update have the same API to simplify
+    # the code a bit.
+    if isinstance(destination, list):
+        _add_to_destination = destination.extend
+    elif isinstance(destination, dict):
+        _add_to_destination = destination.update
+    else:
+        raise ValueError("Destination should be either a list or a dictionary or None")
+
+    # Leaf identified by is_leaf so add it and return
+    if is_leaf is not None and is_leaf(tree):
+        _add_to_destination([(prefix[1:], tree)])
+        return destination
+
+    # List or tuple so recursively add each subtree
     if isinstance(tree, (list, tuple)):
-        for i, t in enumerate(tree):
-            flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
-        return flat_tree
+        for i, item in enumerate(tree):
+            tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
+        return destination
+
+    # Dictionary so recursively add each subtree
     if isinstance(tree, dict):
-        for k, t in tree.items():
-            flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
-        return flat_tree
+        for key, value in tree.items():
+            tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
+        return destination

-    return [(prefix[1:], tree)]
+    # Leaf so add it and return
+    _add_to_destination([(prefix[1:], tree)])
+
+    return destination


-def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
+def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
     """Recreate a Python tree from its flat representation.

     .. code-block:: python
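A quick illustration of the rewritten `tree_flatten`: both destinations produce the same keys, and an existing container is extended in place and returned.

```python
from mlx.utils import tree_flatten

tree = {"layers": [{"w": 1}, {"w": 2}]}

print(tree_flatten(tree))
# [('layers.0.w', 1), ('layers.1.w', 2)]

print(tree_flatten(tree, destination={}))
# {'layers.0.w': 1, 'layers.1.w': 2}

dest = {"prev.key": 0}
tree_flatten(tree, destination=dest)
print(sorted(dest))
# ['layers.0.w', 'layers.1.w', 'prev.key']
```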
@@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
         print(d)
         # {"hello": {"world": 42}}

+        d = tree_unflatten({"hello.world": 42})
+        print(d)
+        # {"hello": {"world": 42}}
+
     Args:
-        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
+        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
            For instance as returned by :meth:`tree_flatten`.

     Returns:
         A Python tree.
     """
-    if len(tree) == 1 and tree[0][0] == "":
-        return tree[0][1]
+    items = tree.items() if isinstance(tree, dict) else tree

-    try:
-        int(tree[0][0].split(".", maxsplit=1)[0])
-        is_list = True
-    except ValueError:
-        is_list = False
+    # Special case when we have just one element in the tree ie not a tree
+    if len(items) == 1:
+        key, value = next(iter(items))
+        if key == "":
+            return value

     # collect children
     children = defaultdict(list)
-    for key, value in tree:
+    for key, value in items:
         current_idx, *next_idx = key.split(".", maxsplit=1)
         next_idx = "" if not next_idx else next_idx[0]
         children[current_idx].append((next_idx, value))

-    # recursively map them to the original container
-    if is_list:
+    # Assume they are a list and fail to dict if the keys are not all integers
+    try:
         keys = sorted((int(idx), idx) for idx in children.keys())
         l = []
         for i, k in keys:
@@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
             l.extend([{} for _ in range(i - len(l))])
             l.append(tree_unflatten(children[k]))
         return l
-    else:
+    except ValueError:
         return {k: tree_unflatten(v) for k, v in children.items()}

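The try/except rewrite infers "list-ness" from whether every child key parses as an integer, instead of probing only the first key, and the `items` indirection lets dictionaries flow through unchanged:

```python
from mlx.utils import tree_unflatten

print(tree_unflatten([("0.w", 1), ("1.w", 2)]))
# [{'w': 1}, {'w': 2}]

print(tree_unflatten({"hello.world": 42}))  # dicts are now accepted directly
# {'hello': {'world': 42}}
```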
@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
                 self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}

         model = DictModule()
-        params = dict(tree_flatten(model.parameters()))
+        params = tree_flatten(model.parameters(), destination={})
         self.assertEqual(len(params), 2)
         self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
        self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))