Compare commits

...

8 Commits

Author SHA1 Message Date
Eric Buehler
9c62563122
Merge 4d68bd3250 into 8402a2acf4 2025-06-13 21:05:44 +02:00
Awni Hannun
8402a2acf4
Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1
Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Eric Buehler
4d68bd3250 Refactor v1/v2 caller code 2025-05-31 09:48:24 -04:00
Eric Buehler
5fbce6c49e Add v2 call 2025-05-31 09:30:51 -04:00
Eric Buehler
0b5c5680f4 Add v1 call 2025-05-31 09:20:22 -04:00
Eric Buehler
221edc4a65 Add the attention kernel 2025-05-31 08:25:25 -04:00
Eric Buehler
190c72739b Add pagedattn primitive 2025-05-31 08:10:33 -04:00
25 changed files with 2740 additions and 668 deletions

View File

@@ -12,6 +12,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp

View File

@@ -194,6 +194,13 @@ struct Power {
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);

View File

@@ -102,6 +102,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp

View File

@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
/* const Stream& s = */ s,
/* Device& d = */ d,
/* const array& a = */ in_unfolded,
/* const array& b = */ wt_transpose,
/* array& c = */ out,
/* int M = */ implicit_M,
/* int N = */ implicit_N,
/* int K = */ implicit_K,
/* int batch_size_out = */ groups,
/* int lda = */ implicit_K * groups,
/* int ldb = */ implicit_K,
/* int ldd = */ implicit_N * groups,
/* bool transpose_a = */ false,
/* bool transpose_b = */ true,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ {1},
/* Strides batch_strides = */ {0},
/* int64_t A_batch_strides = */ int64_t(implicit_K),
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
}
void implicit_gemm_conv_2D_gpu(

View File

@@ -297,6 +297,9 @@ Device::Device() {
device_ = load_device();
default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String());
int ag_tens = arch_[arch_.size() - 3] - '0';
int ag_ones = arch_[arch_.size() - 2] - '0';
arch_gen_ = ag_tens * 10 + ag_ones;
auto arch = arch_.back();
switch (arch) {
case 'p': // phone
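
Illustrative sketch, not part of this diff: the generation parsing added above, assuming the Metal architecture string ends in a two-digit generation followed by a class letter (e.g. "applegpu_g14p"; the exact string format is an assumption, not something this compare confirms).

#include <string>

// Returns the GPU generation encoded in the architecture name,
// e.g. "applegpu_g14p" -> 14, mirroring the two digit reads above.
int parse_arch_gen(const std::string& arch) {
  int tens = arch[arch.size() - 3] - '0';
  int ones = arch[arch.size() - 2] - '0';
  return tens * 10 + ones;
}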

View File

@@ -177,6 +177,10 @@ class Device {
return arch_;
}
int get_architecture_gen() const {
return arch_gen_;
}
void new_queue(int index);
MTL::CommandQueue* get_queue(Stream stream);
@@ -268,6 +272,7 @@ class Device {
library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
int arch_gen_;
int max_ops_per_buffer_;
int max_mb_per_buffer_;
};

View File

@@ -241,6 +241,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int wn,
bool transpose);
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(paged_attention paged_attention.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)

View File

@@ -235,6 +235,13 @@ struct Power {
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (x.real == 0 && x.imag == 0) {
if (metal::isnan(y.real) || metal::isnan(y.imag)) {
auto nan = metal::numeric_limits<float>::quiet_NaN();
return {nan, nan};
}
return {0.0, 0.0};
}
auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
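
Illustrative sketch, not part of this diff: the complex power rule the CUDA and Metal Power kernels in this compare implement, written host-side with std::complex for reference. A zero base yields NaN+NaNj when the exponent carries a NaN component and 0 otherwise; everything else goes through the polar form.

#include <cmath>
#include <complex>
#include <limits>

std::complex<float> complex_power(std::complex<float> base, std::complex<float> expnt) {
  if (base.real() == 0.0f && base.imag() == 0.0f) {
    // 0 raised to anything is 0, unless the exponent carries a NaN.
    if (std::isnan(expnt.real()) || std::isnan(expnt.imag())) {
      float nan = std::numeric_limits<float>::quiet_NaN();
      return {nan, nan};
    }
    return {0.0f, 0.0f};
  }
  // Polar form: base = r * e^(i*theta), expnt = a + b*i, so
  // base^expnt = e^(a*ln r - b*theta) * e^(i*(b*ln r + a*theta)).
  float theta = std::atan2(base.imag(), base.real());
  float ln_r = 0.5f * std::log(base.real() * base.real() + base.imag() * base.imag());
  float mag = std::exp(expnt.real() * ln_r - expnt.imag() * theta);
  float phase = expnt.imag() * ln_r + expnt.real() * theta;
  return {mag * std::cos(phase), mag * std::sin(phase)};
}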

File diff suppressed because it is too large

View File

@@ -0,0 +1,131 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/paged_attention.h"
#include "mlx/backend/metal/kernels/utils.h"
#define instantiate_paged_attention_inner( \
type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
template \
[[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
"_nt" #num_threads "_nsl" #num_simd_lanes \
"_ps" #partition_size)]] [[kernel]] void \
paged_attention< \
type, \
head_size, \
block_size, \
num_threads, \
num_simd_lanes, \
partition_size>( \
device float* exp_sums \
[[buffer(0), function_constant(use_partitioning)]], \
device float* max_logits \
[[buffer(1), function_constant(use_partitioning)]], \
device type* out [[buffer(2)]], \
device const type* q [[buffer(3)]], \
device const type* k_cache [[buffer(4)]], \
device const type* v_cache [[buffer(5)]], \
const constant int& num_kv_heads [[buffer(6)]], \
const constant float& scale [[buffer(7)]], \
const constant float& softcapping [[buffer(8)]], \
device const uint32_t* block_tables [[buffer(9)]], \
device const uint32_t* context_lens [[buffer(10)]], \
const constant int& max_num_blocks_per_seq [[buffer(11)]], \
device const float* alibi_slopes \
[[buffer(12), function_constant(use_alibi)]], \
const constant int& q_stride [[buffer(13)]], \
const constant int& kv_block_stride [[buffer(14)]], \
const constant int& kv_head_stride [[buffer(15)]], \
threadgroup char* shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
uint3 thread_position_in_threadgroup \
[[thread_position_in_threadgroup]], \
uint simd_tid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_paged_attention_v2_reduce_inner( \
type, head_size, num_threads, num_simd_lanes, partition_size) \
template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
"_nt" #num_threads "_nsl" #num_simd_lanes \
"_ps" #partition_size)]] [[kernel]] void \
paged_attention_v2_reduce< \
type, \
head_size, \
num_threads, \
num_simd_lanes, \
partition_size>( \
device type * out [[buffer(0)]], \
const device float* exp_sums [[buffer(1)]], \
const device float* max_logits [[buffer(2)]], \
const device type* tmp_out [[buffer(3)]], \
device uint32_t* context_lens [[buffer(4)]], \
const constant int& max_num_partitions [[buffer(5)]], \
threadgroup char* shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
uint simd_tid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_paged_attention_heads( \
type, block_size, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_inner( \
type, 64, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 80, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 96, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 112, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 128, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 192, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 256, block_size, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_v2_reduce_heads( \
type, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_v2_reduce_inner( \
type, 64, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 80, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 96, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 112, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 128, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 192, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 256, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_block_size( \
type, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_heads( \
type, 8, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_heads( \
type, 16, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_heads( \
type, 32, num_threads, num_simd_lanes, partition_size);
// TODO: tune num_threads = 256
// NOTE: partition_size = 0
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
// TODO: tune num_threads = 256
// NOTE: partition_size = 512
#define instantiate_paged_attention_v2(type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
instantiate_paged_attention_v1(float, 32);
instantiate_paged_attention_v1(bfloat16_t, 32);
instantiate_paged_attention_v1(half, 32);
instantiate_paged_attention_v2(float, 32);
instantiate_paged_attention_v2(bfloat16_t, 32);
instantiate_paged_attention_v2(half, 32);
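
For reference, the host_name strings generated by the macros above follow the concatenation spelled out in instantiate_paged_attention_inner; the host code looks kernels up by the same names. A few illustrative examples, derived from the macros:

// v1, float, head_size 128, block_size 16:
//   paged_attention_float_hs128_bs16_nt256_nsl32_ps0
// v2, half, head_size 64, block_size 32:
//   paged_attention_half_hs64_bs32_nt256_nsl32_ps512
// v2 reduce, bfloat16_t, head_size 96:
//   paged_attention_v2_reduce_bfloat16_t_hs96_nt256_nsl32_ps512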

View File

@@ -33,8 +33,8 @@ template <
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],

File diff suppressed because it is too large

View File

@@ -6,7 +6,34 @@
namespace mlx::core {
void steel_matmul_regular(
template <bool CHECK_AB = true>
void steel_matmul_regular_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
int64_t C_batch_stride = 0,
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul_regular(
const Stream& s,
metal::Device& d,
const array& a,
@@ -21,14 +48,61 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies);
int64_t matrix_stride_out) {
return steel_matmul_regular_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides batch_strides = */ batch_strides,
/* int64_t A_batch_stride = */ A_batch_stride,
/* int64_t B_batch_stride = */ B_batch_stride,
/* int64_t matrix_stride_out = */ matrix_stride_out);
}
void steel_matmul(
template <bool CHECK_AB = true>
void steel_matmul_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {},
Strides C_batch_stride = {},
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul(
const Stream& s,
metal::Device& d,
const array& a,
@@ -45,6 +119,26 @@ void steel_matmul(
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {});
Strides B_batch_stride = {}) {
return steel_matmul_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}
} // namespace mlx::core

View File

@@ -288,4 +288,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
}
if (no_copy) {
if (x.is_donatable()) {
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
}
if (no_copy) {
if (x.is_donatable()) {
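
// Note, not part of this diff: the added `x.shape(-2) == 1` clause keeps the
// no-copy fast path for inputs whose second-to-last axis has a single row
// (e.g. shape (4, 1, 64)); with only one row along that axis its stride value
// is irrelevant, so the old `s == x.shape().back()` test was stricter than needed.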

View File

@@ -0,0 +1,324 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/paged_attention_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
static void run_paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const bool use_partitioning,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
const int partition_size = use_partitioning ? 512 : 0;
const int num_threads = 256;
const int num_simd_lanes = 32;
const bool use_alibi = alibi.has_value();
std::string type_string = get_type_string(q.dtype());
std::string kname;
kname.reserve(64);
concatenate(
kname,
"paged_attention_",
type_string,
"_hs",
head_size,
"_bs",
block_size,
"_nt",
num_threads,
"_nsl",
num_simd_lanes,
"_ps",
partition_size);
auto template_def = get_template_definition(
kname,
"paged_attention",
type_string,
head_size,
block_size,
num_threads,
num_simd_lanes,
partition_size);
// Encode and dispatch kernel
metal::MTLFCList func_consts = {
{use_partitioning, MTL::DataType::DataTypeBool, 10},
{use_alibi, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
auto kernel = get_paged_attention_kernel(
d, kname, hash_name, func_consts, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
int local_max_num_partitions = 1;
if (use_partitioning) {
local_max_num_partitions =
(max_context_len + partition_size - 1) / partition_size;
}
int logits_size = use_partitioning ? partition_size * size_of(float32) : 0;
int outputs_size = use_partitioning
? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32)
: 0;
int shared_mem_size =
use_partitioning ? std::max(logits_size, outputs_size) : 0;
if (use_partitioning) {
compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
}
if (use_partitioning) {
auto tmp_out = array(
{num_seqs, num_heads, local_max_num_partitions, head_size}, float32);
tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
auto exp_sums =
array({num_seqs, num_heads, local_max_num_partitions}, float32);
exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
std::vector<array> temporaries = {tmp_out, exp_sums};
compute_encoder.set_output_array(tmp_out, 0);
compute_encoder.set_output_array(exp_sums, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(temporaries), s.index);
} else {
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, 1);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
void paged_attention_v1(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
run_paged_attention(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size,
block_size,
num_kv_heads,
scale,
softcapping,
max_context_len,
max_num_blocks_per_seq,
/*use_partitioning=*/false,
alibi,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
out,
d,
s);
}
void paged_attention_v2(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const int /* max_num_partitions */,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
run_paged_attention(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size,
block_size,
num_kv_heads,
scale,
softcapping,
max_context_len,
max_num_blocks_per_seq,
/*use_partitioning=*/true,
alibi,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
out,
d,
s);
}
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
out.set_data(allocator::malloc(out.nbytes()));
auto& q = inputs[0];
auto& k_cache = inputs[1];
auto& v_cache = inputs[2];
auto& block_tables = inputs[3];
auto& context_lens = inputs[4];
const auto alibi_slopes =
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
if (use_v1_) {
paged_attention_v1(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(1.),
max_context_len_,
max_num_blocks_per_seq_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
} else {
paged_attention_v2(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(1.),
max_context_len_,
max_num_blocks_per_seq_,
max_num_partitions_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
}
}
} // namespace mlx::core::paged_attention

View File

@@ -17,6 +17,7 @@
#include "mlx/linalg.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
#include "mlx/paged_attention.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"

View File

@@ -2847,21 +2847,6 @@ array matmul(
"[matmul] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// complex matmul using Karatsuba's Algorithm
if (a.dtype() == complex64 || b.dtype() == complex64) {
@@ -2883,6 +2868,22 @@
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
@@ -4240,6 +4241,16 @@ array addmm(
"have at least one dimension.");
}
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
@@ -4257,16 +4268,6 @@
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "

mlx/paged_attention.cpp Normal file
View File

@@ -0,0 +1,170 @@
// Copyright © 2023-2024 Apple Inc.
// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES
#include <algorithm>
#include <climits>
#include <cmath>
#include <numeric>
#include <set>
#include <sstream>
#include <stdexcept>
#include "mlx/paged_attention_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
array paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
int max_context_len,
float softmax_scale,
std::optional<array> alibi_slopes = std::nullopt,
std::optional<float> softcapping = std::nullopt,
StreamOrDevice s_ = {}) {
auto s = to_stream(s_);
// supported dtypes
if (!issubdtype(q.dtype(), floating)) {
throw std::invalid_argument(
"[paged_attention] Only real floating types are supported");
}
if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) {
throw std::invalid_argument(
"[paged_attention] q/k_cache/v_cache dtype must match");
}
if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) {
throw std::invalid_argument(
"[paged_attention] block_tables/context_lens dtype must be uint32");
}
// rank checks
if (q.ndim() != 3)
throw std::invalid_argument("[paged_attention] `q` must be rank-3");
if (k_cache.ndim() != 5)
throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5");
if (v_cache.ndim() != 4)
throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4");
if (block_tables.ndim() != 2)
throw std::invalid_argument(
"[paged_attention] `block_tables` must be rank-2");
if (context_lens.ndim() != 1)
throw std::invalid_argument(
"[paged_attention] `context_lens` must be rank-1");
// Shape consistency
const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size]
const auto& kc_shape = k_cache.shape();
const auto& vc_shape = v_cache.shape();
const auto& bt_shape = block_tables.shape();
const auto& cl_shape = context_lens.shape();
int num_seqs = q_shape[0];
int num_heads = q_shape[1];
int head_size = q_shape[2];
// Allowed head sizes
switch (head_size) {
case 64:
case 80:
case 96:
case 112:
case 128:
case 192:
case 256:
break;
default:
throw std::invalid_argument(
"[paged_attention] `head_size` must be one of "
"{64, 80, 96, 112, 128, 192, 256}");
}
int max_num_blocks_per_seq = bt_shape[1];
// block_tables first dimension must match num_seqs
if (bt_shape[0] != num_seqs) {
std::stringstream ss;
ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0]
<< ") must equal q.shape[0] (" << num_seqs << ")";
throw std::invalid_argument(ss.str());
}
// Extract k_cache dimensions
int num_blocks = kc_shape[0];
int num_kv_heads = kc_shape[1];
int head_size_kc = kc_shape[2];
int block_size = kc_shape[3];
int x = kc_shape[4];
if (head_size_kc * x != head_size) {
std::stringstream ss;
ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x
<< ") must equal q head_size (" << head_size << ")";
throw std::invalid_argument(ss.str());
}
// v_cache must match the derived dimensions
if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads &&
vc_shape[2] == head_size && vc_shape[3] == block_size)) {
throw std::invalid_argument(
"[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`");
}
// context_lens length must match num_seqs
if (cl_shape[0] != num_seqs) {
std::stringstream ss;
ss << "paged_attention: context_lens length (" << cl_shape[0]
<< ") must equal q.shape[0] (" << num_seqs << ")";
throw std::invalid_argument(ss.str());
}
constexpr int partition_size = 512;
int max_num_partitions =
(max_context_len + partition_size - 1) / partition_size; // ceildiv
bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
(partition_size % block_size == 0);
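// Note, not part of this diff: worked example of the rule above. With
// partition_size = 512 and max_context_len = 1300, max_num_partitions =
// ceil(1300 / 512) = 3, so use_v1 is true only if num_seqs * num_heads > 512
// (plenty of threadgroups without splitting the context) and block_size
// divides 512 evenly; otherwise the partitioned v2 kernel is dispatched.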
auto out_shape = q.shape();
auto inputs = std::vector{
std::move(q),
std::move(k_cache),
std::move(v_cache),
std::move(block_tables),
std::move(context_lens)};
if (alibi_slopes.has_value()) {
inputs.push_back(std::move(alibi_slopes.value()));
}
int q_stride = q.strides()[0];
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
return array(
std::move(out_shape),
q.dtype(),
std::make_shared<PagedAttention>(
to_stream(s),
use_v1,
max_context_len,
head_size,
block_size,
num_kv_heads,
max_num_blocks_per_seq,
max_num_partitions,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
softmax_scale,
softcapping),
inputs);
}
} // namespace mlx::core::paged_attention
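
Minimal usage sketch, not part of this diff: calling the op above with the shapes its validation expects. q is [num_seqs, num_heads, head_size], k_cache is [num_blocks, num_kv_heads, head_size/x, block_size, x], v_cache is [num_blocks, num_kv_heads, head_size, block_size], and block_tables/context_lens are uint32. The concrete sizes, the x = 8 packing factor, and the zero-filled placeholder data are made-up illustration values.

#include <cmath>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  int num_seqs = 2, num_heads = 8, num_kv_heads = 8, head_size = 128;
  int num_blocks = 64, block_size = 16, x = 8, max_blocks_per_seq = 32;

  // Zero-filled placeholders; a real cache would hold packed K/V pages.
  auto q = zeros({num_seqs, num_heads, head_size}, float16);
  auto k_cache = zeros({num_blocks, num_kv_heads, head_size / x, block_size, x}, float16);
  auto v_cache = zeros({num_blocks, num_kv_heads, head_size, block_size}, float16);
  auto block_tables = zeros({num_seqs, max_blocks_per_seq}, uint32);
  auto context_lens = full({num_seqs}, 7, uint32);

  auto out = paged_attention::paged_attention(
      q, k_cache, v_cache, block_tables, context_lens,
      /* max_context_len = */ 512,
      /* softmax_scale = */ 1.0f / std::sqrt(float(head_size)));
  // out has the same shape as q: [num_seqs, num_heads, head_size].
  eval(out);
  return 0;
}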

mlx/paged_attention.h Normal file
View File

@@ -0,0 +1,34 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
/**
* \defgroup ops Paged attention operations
* @{
*/
/** PagedAttention operation. */
array paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
int max_context_len,
float softmax_scale,
std::optional<array> alibi_slopes = std::nullopt,
std::optional<float> softcapping = std::nullopt,
StreamOrDevice s_ = {});
/** @} */
} // namespace mlx::core::paged_attention

View File

@@ -0,0 +1,82 @@
// Copyright © 2023-2024 Apple Inc.
// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES
#include <optional>
#include "mlx/primitives.h"
namespace mlx::core::paged_attention {
class PagedAttention : public UnaryPrimitive {
public:
explicit PagedAttention(
Stream stream,
bool use_v1,
int max_context_len,
int head_size,
int block_size,
int num_kv_heads,
int max_num_blocks_per_seq,
int max_num_partitions,
int q_stride,
int kv_block_stride,
int kv_head_stride,
int num_heads,
int num_seqs,
float softmax_scale,
std::optional<float> softcapping = std::nullopt)
: UnaryPrimitive(stream),
use_v1_(use_v1),
max_context_len_(max_context_len),
head_size_(head_size),
block_size_(block_size),
num_kv_heads_(num_kv_heads),
max_num_blocks_per_seq_(max_num_blocks_per_seq),
max_num_partitions_(max_num_partitions),
q_stride_(q_stride),
kv_block_stride_(kv_block_stride),
kv_head_stride_(kv_head_stride),
num_heads_(num_heads),
num_seqs_(num_seqs),
softmax_scale_(softmax_scale),
softcapping_(softcapping) {}
void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
DEFINE_PRINT(PagedAttention);
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(
max_context_len_,
head_size_,
block_size_,
softmax_scale_,
softcapping_);
}
private:
bool use_v1_;
int max_context_len_;
int head_size_;
int block_size_;
int num_kv_heads_;
int max_num_blocks_per_seq_;
int max_num_partitions_;
int q_stride_;
int kv_block_stride_;
int kv_head_stride_;
int num_heads_;
int num_seqs_;
float softmax_scale_;
std::optional<float> softcapping_ = std::nullopt;
};
} // namespace mlx::core::paged_attention

View File

@@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
os << val;
}
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
os << val;
os << val.real();
if (val.imag() >= 0 || std::isnan(val.imag())) {
os << "+" << val.imag() << "j";
} else {
os << "-" << -val.imag() << "j";
}
}
PrintFormatter& get_global_formatter() {
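
For reference, not part of this diff: a few values as the formatter above prints them.

// complex64_t{1.5f, -2.0f}  ->  1.5-2j
// complex64_t{0.0f,  3.0f}  ->  0+3j
// complex64_t{2.0f,  NAN}   ->  2+nanj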

View File

@@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
c_np = np.matmul(np.array(a).T, b)
self.assertTrue(np.allclose(c, c_np))
# Check shapes
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
self.assertEqual((a @ b).shape, (2,))
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
c = mx.random.normal((2,))
self.assertEqual(mx.addmm(c, a, b).shape, (2,))
def test_complex_gemm(self):
M = 16
K = 50

View File

@@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
)
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
def test_complex_power(self):
out = mx.power(mx.array(0j), 2)
self.assertEqual(out.item(), 0j)
out = mx.power(mx.array(0j), float("nan"))
self.assertTrue(mx.isnan(out))
class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self):