mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
7 Commits
14d22ddedb
...
49b14ddb25
Author | SHA1 | Date | |
---|---|---|---|
![]() |
49b14ddb25 | ||
![]() |
b8022c578a | ||
![]() |
4d68bd3250 | ||
![]() |
5fbce6c49e | ||
![]() |
0b5c5680f4 | ||
![]() |
221edc4a65 | ||
![]() |
190c72739b |
@ -12,6 +12,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
|
@ -8,6 +8,7 @@ target_sources(
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
|
@ -125,13 +125,12 @@ constexpr bool supports_binary_op() {
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -146,7 +145,6 @@ void binary_op_gpu_inplace(
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
@ -219,20 +217,6 @@ void binary_op_gpu_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
@ -243,8 +227,7 @@ void binary_op_gpu(
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
std::vector<array> outputs{out};
|
||||
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
@ -254,14 +237,6 @@ void binary_op_gpu(
|
||||
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = outputs[0].primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
|
248
mlx/backend/cuda/binary_two.cu
Normal file
248
mlx/backend/cuda/binary_two.cu
Normal file
@ -0,0 +1,248 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// Scalar/scalar variant of the two-output binary kernel.
//
// Each thread takes its rank within the grid; threads whose rank is not below
// `size` do nothing. Both operands are read from element 0 and the two
// components of Op's result are stored to element 0 of out_a and out_b.
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT tid = cg::this_grid().thread_rank();
  if (!(tid < size)) {
    return;
  }
  auto result = Op{}(a[0], b[0]);
  out_a[0] = result[0];
  out_b[0] = result[1];
}
|
||||
|
||||
// Scalar/vector variant of the two-output binary kernel.
//
// The left operand is always a[0]; the right operand is b[tid], where tid is
// the thread's rank within the grid. Threads with tid >= size do nothing.
// The two components of Op's result go to out_a[tid] and out_b[tid].
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT tid = cg::this_grid().thread_rank();
  if (!(tid < size)) {
    return;
  }
  auto result = Op{}(a[0], b[tid]);
  out_a[tid] = result[0];
  out_b[tid] = result[1];
}
|
||||
|
||||
// Vector/scalar variant of the two-output binary kernel.
//
// The left operand is a[tid], where tid is the thread's rank within the grid;
// the right operand is always b[0]. Threads with tid >= size do nothing.
// The two components of Op's result go to out_a[tid] and out_b[tid].
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT tid = cg::this_grid().thread_rank();
  if (!(tid < size)) {
    return;
  }
  auto result = Op{}(a[tid], b[0]);
  out_a[tid] = result[0];
  out_b[tid] = result[1];
}
|
||||
|
||||
// Vector/vector variant of the two-output binary kernel.
//
// Both operands are read at the thread's rank within the grid, and the two
// components of Op's result are written at that same position in out_a and
// out_b. Threads whose rank is not below `size` do nothing.
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT tid = cg::this_grid().thread_rank();
  if (!(tid < size)) {
    return;
  }
  auto result = Op{}(a[tid], b[tid]);
  out_a[tid] = result[0];
  out_b[tid] = result[1];
}
|
||||
|
||||
// General (strided) two-output binary kernel with a compile-time rank NDIM.
//
// The thread's rank within the grid is used as the output position; the
// matching input offsets are obtained from elem_to_loc_nd<NDIM> using the
// shape and the per-input strides. Threads whose rank is not below `size`
// do nothing.
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  IdxT tid = cg::this_grid().thread_rank();
  if (!(tid < size)) {
    return;
  }
  auto [a_loc, b_loc] = elem_to_loc_nd<NDIM>(
      tid, shape.data(), a_strides.data(), b_strides.data());
  auto result = Op{}(a[a_loc], b[b_loc]);
  out_a[tid] = result[0];
  out_b[tid] = result[1];
}
|
||||
|
||||
// General (strided) two-output binary kernel with a run-time rank `ndim`.
//
// The thread's rank within the grid is used as the output position; the
// matching input offsets are obtained from elem_to_loc_4d using the shape,
// the per-input strides and `ndim`. Threads whose rank is not below `size`
// do nothing.
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  IdxT tid = cg::this_grid().thread_rank();
  if (!(tid < size)) {
    return;
  }
  auto [a_loc, b_loc] = elem_to_loc_4d(
      tid, shape.data(), a_strides.data(), b_strides.data(), ndim);
  auto result = Op{}(a[a_loc], b[b_loc]);
  out_a[tid] = result[0];
  out_b[tid] = result[1];
}
|
||||
|
||||
// Compile-time predicate telling whether the two-output kernels may be
// instantiated for (Op, In, Out).
//
// Only DivMod is accepted, and only when the input and output element types
// are identical and that type is integral or floating point.
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
  constexpr bool is_divmod = std::is_same_v<Op, DivMod>;
  constexpr bool same_types = std::is_same_v<In, Out>;
  constexpr bool numeric_out = std::is_integral_v<Out> || is_floating_v<Out>;
  return is_divmod && same_types && numeric_out;
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
// Launch the two-output binary kernel for `Op` on outputs that the caller has
// already allocated.
//
// inputs[0] / inputs[1] are the operands; outputs[0] / outputs[1] receive the
// two components of Op's result. `op` is only used to build the error message
// and `s` selects the command encoder.
//
// This function deliberately does NOT call set_binary_op_output_data:
// binary_op_gpu() allocates both outputs right before calling in here, and
// allocating a second time would replace the buffers that were just set up.
//
// Throws std::runtime_error when (Op, input dtype, output dtype) is not
// supported according to cu::supports_binary_op.
template <typename Op>
void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
    const Stream& s) {
  assert(inputs.size() > 1);
  assert(outputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  auto& out_a = outputs[0];
  auto& out_b = outputs[1];

  // Nothing to compute for empty outputs.
  if (out_a.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out_a);
  encoder.set_output_array(out_b);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;

          auto bopt = get_binary_op_type(a, b);
          if (bopt == BinaryOpType::General) {
            // Strided inputs: collapse contiguous dims, then pick the
            // fixed-rank kernel for ndim <= 3 and the generic one otherwise.
            auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
            auto& a_strides = strides[0];
            auto& b_strides = strides[1];
            // 64-bit indexing is only used when some buffer exceeds INT32_MAX.
            bool large = a.data_size() > INT32_MAX ||
                b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
            MLX_SWITCH_BOOL(large, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
              int ndim = shape.size();
              if (ndim <= 3) {
                MLX_SWITCH_1_2_3(ndim, NDIM, {
                  auto kernel =
                      &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
                  auto [num_blocks, block_dims] =
                      get_launch_args(kernel, out_a, large);
                  kernel<<<num_blocks, block_dims, 0, stream>>>(
                      a.data<InType>(),
                      b.data<InType>(),
                      out_a.data<OutType>(),
                      out_b.data<OutType>(),
                      out_a.size(),
                      const_param<NDIM>(shape),
                      const_param<NDIM>(a_strides),
                      const_param<NDIM>(b_strides));
                });
              } else {
                auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
                auto [num_blocks, block_dims] =
                    get_launch_args(kernel, out_a, large);
                kernel<<<num_blocks, block_dims, 0, stream>>>(
                    a.data<InType>(),
                    b.data<InType>(),
                    out_a.data<OutType>(),
                    out_b.data<OutType>(),
                    out_a.size(),
                    const_param(shape),
                    const_param(a_strides),
                    const_param(b_strides),
                    ndim);
              }
            });
          } else {
            // Contiguous cases: choose the scalar/vector kernel variant that
            // matches the binary op type (scalar-scalar is the default).
            MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
              if (bopt == BinaryOpType::ScalarVector) {
                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorScalar) {
                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorVector) {
                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
              }
              auto [num_blocks, block_dims] = get_launch_args(
                  kernel,
                  out_a.data_size(),
                  out_a.shape(),
                  out_a.strides(),
                  LARGE);
              kernel<<<num_blocks, block_dims, 0, stream>>>(
                  a.data<InType>(),
                  b.data<InType>(),
                  out_a.data<OutType>(),
                  out_b.data<OutType>(),
                  out_a.data_size());
            });
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do binary op {} on inputs of {} with result of {}.",
              op,
              dtype_to_string(a.dtype()),
              dtype_to_string(out_a.dtype())));
        }
      });
    });
  });
}
|
||||
|
||||
// Allocate both outputs for the two-output binary op `Op` and run it.
//
// The binary op type is derived from inputs[0] and inputs[1]; each of
// outputs[0] and outputs[1] is prepared with set_binary_op_output_data before
// the work is handed to binary_op_gpu_inplace.
template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
    const Stream& s) {
  auto& lhs = inputs[0];
  auto& rhs = inputs[1];
  const auto op_type = get_binary_op_type(lhs, rhs);
  for (int i = 0; i < 2; ++i) {
    set_binary_op_output_data(lhs, rhs, outputs[i], op_type);
  }
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
|
||||
|
||||
// GPU evaluation entry point for DivMod: opens an NVTX range for profiling
// and forwards to the two-output binary op driver on the stream of the first
// output's primitive.
void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range profile_range("DivMod::eval_gpu");
  auto& target_stream = outputs[0].primitive().stream();
  binary_op_gpu<cu::DivMod>(
      inputs, outputs, get_primitive_string(this), target_stream);
}
|
||||
|
||||
} // namespace mlx::core
|
@ -22,7 +22,7 @@ struct FloorDivide {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return trunc(x / y);
|
||||
return truncf(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -132,7 +132,7 @@ struct LogAddExp {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
|
@ -5,7 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 8
|
||||
#define MAX_NDIM 10
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
|
@ -71,10 +71,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
throw std::runtime_error(#func " has no CUDA implementation."); \
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
@ -83,7 +81,6 @@ NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Scan)
|
||||
|
@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
int nsort = in.shape(axis);
|
||||
int nsegments = in.data_size() / nsort;
|
||||
int last_dim = in.ndim() - 1;
|
||||
|
||||
// If we are not sorting the innermost dimension of a contiguous array,
|
||||
@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||
encoder.add_temporary(out);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
indices.data<uint32_t>(),
|
||||
out.data<uint32_t>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data<Type>(),
|
||||
out.data<Type>(),
|
||||
in.data_size(),
|
||||
nsegments,
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
stream);
|
||||
@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
gpu_sort(stream(), inputs[0], out, axis_, false);
|
||||
}
|
||||
|
||||
// ArgPartition on the GPU is delegated to gpu_sort with the argsort flag set,
// operating on inputs[0] along axis_.
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range profile_range("ArgPartition::eval_gpu");
  constexpr bool return_indices = true;
  gpu_sort(stream(), inputs[0], out, axis_, return_indices);
}
|
||||
|
||||
// Partition on the GPU is delegated to gpu_sort with the argsort flag
// cleared, operating on inputs[0] along axis_.
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range profile_range("Partition::eval_gpu");
  constexpr bool return_indices = false;
  gpu_sort(stream(), inputs[0], out, axis_, return_indices);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -102,6 +102,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
|
@ -241,6 +241,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
int wn,
|
||||
bool transpose);
|
||||
|
||||
MTL::ComputePipelineState* get_paged_attention_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string&);
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
|
@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
|
||||
reduction/reduce_row.h)
|
||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(paged_attention paged_attention.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
build_kernel(sort sort.h)
|
||||
|
1196
mlx/backend/metal/kernels/paged_attention.h
Normal file
1196
mlx/backend/metal/kernels/paged_attention.h
Normal file
File diff suppressed because it is too large
Load Diff
131
mlx/backend/metal/kernels/paged_attention.metal
Normal file
131
mlx/backend/metal/kernels/paged_attention.metal
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/paged_attention.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define instantiate_paged_attention_inner( \
|
||||
type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
|
||||
template \
|
||||
[[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
|
||||
"_nt" #num_threads "_nsl" #num_simd_lanes \
|
||||
"_ps" #partition_size)]] [[kernel]] void \
|
||||
paged_attention< \
|
||||
type, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
num_threads, \
|
||||
num_simd_lanes, \
|
||||
partition_size>( \
|
||||
device float* exp_sums \
|
||||
[[buffer(0), function_constant(use_partitioning)]], \
|
||||
device float* max_logits \
|
||||
[[buffer(1), function_constant(use_partitioning)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
device const type* q [[buffer(3)]], \
|
||||
device const type* k_cache [[buffer(4)]], \
|
||||
device const type* v_cache [[buffer(5)]], \
|
||||
const constant int& num_kv_heads [[buffer(6)]], \
|
||||
const constant float& scale [[buffer(7)]], \
|
||||
const constant float& softcapping [[buffer(8)]], \
|
||||
device const uint32_t* block_tables [[buffer(9)]], \
|
||||
device const uint32_t* context_lens [[buffer(10)]], \
|
||||
const constant int& max_num_blocks_per_seq [[buffer(11)]], \
|
||||
device const float* alibi_slopes \
|
||||
[[buffer(12), function_constant(use_alibi)]], \
|
||||
const constant int& q_stride [[buffer(13)]], \
|
||||
const constant int& kv_block_stride [[buffer(14)]], \
|
||||
const constant int& kv_head_stride [[buffer(15)]], \
|
||||
threadgroup char* shared_mem [[threadgroup(0)]], \
|
||||
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
|
||||
uint3 thread_position_in_threadgroup \
|
||||
[[thread_position_in_threadgroup]], \
|
||||
uint simd_tid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, head_size, num_threads, num_simd_lanes, partition_size) \
|
||||
template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
|
||||
"_nt" #num_threads "_nsl" #num_simd_lanes \
|
||||
"_ps" #partition_size)]] [[kernel]] void \
|
||||
paged_attention_v2_reduce< \
|
||||
type, \
|
||||
head_size, \
|
||||
num_threads, \
|
||||
num_simd_lanes, \
|
||||
partition_size>( \
|
||||
device type * out [[buffer(0)]], \
|
||||
const device float* exp_sums [[buffer(1)]], \
|
||||
const device float* max_logits [[buffer(2)]], \
|
||||
const device type* tmp_out [[buffer(3)]], \
|
||||
device uint32_t* context_lens [[buffer(4)]], \
|
||||
const constant int& max_num_partitions [[buffer(5)]], \
|
||||
threadgroup char* shared_mem [[threadgroup(0)]], \
|
||||
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
||||
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
|
||||
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
|
||||
uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
|
||||
uint simd_tid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_paged_attention_heads( \
|
||||
type, block_size, num_threads, num_simd_lanes, partition_size) \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 64, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 80, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 96, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 112, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 128, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 192, block_size, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_inner( \
|
||||
type, 256, block_size, num_threads, num_simd_lanes, partition_size);
|
||||
|
||||
#define instantiate_paged_attention_v2_reduce_heads( \
|
||||
type, num_threads, num_simd_lanes, partition_size) \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 64, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 80, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 96, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 112, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 128, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 192, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_v2_reduce_inner( \
|
||||
type, 256, num_threads, num_simd_lanes, partition_size);
|
||||
|
||||
#define instantiate_paged_attention_block_size( \
|
||||
type, num_threads, num_simd_lanes, partition_size) \
|
||||
instantiate_paged_attention_heads( \
|
||||
type, 8, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_heads( \
|
||||
type, 16, num_threads, num_simd_lanes, partition_size); \
|
||||
instantiate_paged_attention_heads( \
|
||||
type, 32, num_threads, num_simd_lanes, partition_size);
|
||||
|
||||
// TODO: tune num_threads = 256
|
||||
// NOTE: partition_size = 0
|
||||
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
|
||||
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
|
||||
|
||||
// TODO: tune num_threads = 256
|
||||
// NOTE: partition_size = 512
|
||||
#define instantiate_paged_attention_v2(type, num_simd_lanes) \
|
||||
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
|
||||
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
|
||||
|
||||
instantiate_paged_attention_v1(float, 32);
|
||||
instantiate_paged_attention_v1(bfloat16_t, 32);
|
||||
instantiate_paged_attention_v1(half, 32);
|
||||
|
||||
instantiate_paged_attention_v2(float, 32);
|
||||
instantiate_paged_attention_v2(bfloat16_t, 32);
|
||||
instantiate_paged_attention_v2(half, 32);
|
@ -288,4 +288,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
// Look up the paged-attention pipeline state named `kernel_name` in the "mlx"
// library, keyed by `hash_name` and specialised with `func_consts`.
// The trailing string argument is accepted but not used by this
// implementation.
MTL::ComputePipelineState* get_paged_attention_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
  return pipeline;
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
324
mlx/backend/metal/paged_attention.cpp
Normal file
324
mlx/backend/metal/paged_attention.cpp
Normal file
@ -0,0 +1,324 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/paged_attention_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
static void run_paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const float softcapping,
|
||||
const int max_context_len,
|
||||
const int max_num_blocks_per_seq,
|
||||
const bool use_partitioning,
|
||||
const std::optional<array> alibi,
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int num_heads,
|
||||
const int num_seqs,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
const int partition_size = use_partitioning ? 512 : 0;
|
||||
const int num_threads = 256;
|
||||
const int num_simd_lanes = 32;
|
||||
const bool use_alibi = alibi.has_value();
|
||||
|
||||
std::string type_string = get_type_string(q.dtype());
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"paged_attention_",
|
||||
type_string,
|
||||
"_hs",
|
||||
head_size,
|
||||
"_bs",
|
||||
block_size,
|
||||
"_nt",
|
||||
num_threads,
|
||||
"_nsl",
|
||||
num_simd_lanes,
|
||||
"_ps",
|
||||
partition_size);
|
||||
|
||||
auto template_def = get_template_definition(
|
||||
kname,
|
||||
"paged_attention",
|
||||
type_string,
|
||||
head_size,
|
||||
block_size,
|
||||
num_threads,
|
||||
num_simd_lanes,
|
||||
partition_size);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
metal::MTLFCList func_consts = {
|
||||
{use_partitioning, MTL::DataType::DataTypeBool, 10},
|
||||
{use_alibi, MTL::DataType::DataTypeBool, 20},
|
||||
};
|
||||
|
||||
std::string hash_name = kname;
|
||||
auto kernel = get_paged_attention_kernel(
|
||||
d, kname, hash_name, func_consts, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int local_max_num_partitions = 1;
|
||||
if (use_partitioning) {
|
||||
local_max_num_partitions =
|
||||
(max_context_len + partition_size - 1) / partition_size;
|
||||
}
|
||||
|
||||
int logits_size = use_partitioning ? partition_size * size_of(float32) : 0;
|
||||
int outputs_size = use_partitioning
|
||||
? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32)
|
||||
: 0;
|
||||
int shared_mem_size =
|
||||
use_partitioning ? std::max(logits_size, outputs_size) : 0;
|
||||
if (use_partitioning) {
|
||||
compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
|
||||
}
|
||||
|
||||
if (use_partitioning) {
|
||||
auto tmp_out = array(
|
||||
{num_seqs, num_heads, local_max_num_partitions, head_size}, float32);
|
||||
tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
|
||||
auto exp_sums =
|
||||
array({num_seqs, num_heads, local_max_num_partitions}, float32);
|
||||
exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
|
||||
|
||||
std::vector<array> temporaries = {tmp_out, exp_sums};
|
||||
|
||||
compute_encoder.set_output_array(tmp_out, 0);
|
||||
compute_encoder.set_output_array(exp_sums, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
compute_encoder.set_input_array(q, 3);
|
||||
compute_encoder.set_input_array(k_cache, 4);
|
||||
compute_encoder.set_input_array(v_cache, 5);
|
||||
|
||||
compute_encoder.set_bytes(num_kv_heads, 6);
|
||||
compute_encoder.set_bytes(scale, 7);
|
||||
compute_encoder.set_bytes(softcapping, 8);
|
||||
|
||||
compute_encoder.set_input_array(block_tables, 9);
|
||||
compute_encoder.set_input_array(context_lens, 10);
|
||||
|
||||
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
|
||||
|
||||
if (use_alibi) {
|
||||
compute_encoder.set_input_array(alibi.value(), 12);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(q_stride, 13);
|
||||
compute_encoder.set_bytes(kv_block_stride, 14);
|
||||
compute_encoder.set_bytes(kv_head_stride, 15);
|
||||
|
||||
MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions);
|
||||
MTL::Size group_dims(num_threads, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
d.add_temporaries(std::move(temporaries), s.index);
|
||||
} else {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
compute_encoder.set_input_array(q, 3);
|
||||
compute_encoder.set_input_array(k_cache, 4);
|
||||
compute_encoder.set_input_array(v_cache, 5);
|
||||
|
||||
compute_encoder.set_bytes(num_kv_heads, 6);
|
||||
compute_encoder.set_bytes(scale, 7);
|
||||
compute_encoder.set_bytes(softcapping, 8);
|
||||
|
||||
compute_encoder.set_input_array(block_tables, 9);
|
||||
compute_encoder.set_input_array(context_lens, 10);
|
||||
|
||||
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
|
||||
|
||||
if (use_alibi) {
|
||||
compute_encoder.set_input_array(alibi.value(), 12);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(q_stride, 13);
|
||||
compute_encoder.set_bytes(kv_block_stride, 14);
|
||||
compute_encoder.set_bytes(kv_head_stride, 15);
|
||||
|
||||
MTL::Size grid_dims(num_heads, num_seqs, 1);
|
||||
MTL::Size group_dims(num_threads, 1, 1);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void paged_attention_v1(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const float softcapping,
|
||||
const int max_context_len,
|
||||
const int max_num_blocks_per_seq,
|
||||
const std::optional<array> alibi,
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int num_heads,
|
||||
const int num_seqs,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
run_paged_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
softcapping,
|
||||
max_context_len,
|
||||
max_num_blocks_per_seq,
|
||||
/*use_partitioning=*/false,
|
||||
alibi,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
num_heads,
|
||||
num_seqs,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
|
||||
void paged_attention_v2(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const float softcapping,
|
||||
const int max_context_len,
|
||||
const int max_num_blocks_per_seq,
|
||||
const int /* max_num_partitions */,
|
||||
const std::optional<array> alibi,
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const int num_heads,
|
||||
const int num_seqs,
|
||||
array& out,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
run_paged_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
softcapping,
|
||||
max_context_len,
|
||||
max_num_blocks_per_seq,
|
||||
/*use_partitioning=*/true,
|
||||
alibi,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
num_heads,
|
||||
num_seqs,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
|
||||
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& q = inputs[0];
|
||||
auto& k_cache = inputs[1];
|
||||
auto& v_cache = inputs[2];
|
||||
auto& block_tables = inputs[3];
|
||||
auto& context_lens = inputs[4];
|
||||
const auto alibi_slopes =
|
||||
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
|
||||
|
||||
if (use_v1_) {
|
||||
paged_attention_v1(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size_,
|
||||
block_size_,
|
||||
num_kv_heads_,
|
||||
softmax_scale_,
|
||||
softcapping_.value_or(1.),
|
||||
max_context_len_,
|
||||
max_num_blocks_per_seq_,
|
||||
alibi_slopes,
|
||||
q_stride_,
|
||||
kv_block_stride_,
|
||||
kv_head_stride_,
|
||||
num_heads_,
|
||||
num_seqs_,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
} else {
|
||||
paged_attention_v2(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
head_size_,
|
||||
block_size_,
|
||||
num_kv_heads_,
|
||||
softmax_scale_,
|
||||
softcapping_.value_or(1.),
|
||||
max_context_len_,
|
||||
max_num_blocks_per_seq_,
|
||||
max_num_partitions_,
|
||||
alibi_slopes,
|
||||
q_stride_,
|
||||
kv_block_stride_,
|
||||
kv_head_stride_,
|
||||
num_heads_,
|
||||
num_seqs_,
|
||||
out,
|
||||
d,
|
||||
s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core::paged_attention
|
@ -17,6 +17,7 @@
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/paged_attention.h"
|
||||
#include "mlx/random.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/transforms.h"
|
||||
|
170
mlx/paged_attention.cpp
Normal file
170
mlx/paged_attention.cpp
Normal file
@ -0,0 +1,170 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// Required for using M_PI in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/paged_attention_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
/**
 * Validate the operands of a paged-attention call and build the lazy output
 * array backed by a PagedAttention primitive.
 *
 * Expected layouts (enforced below):
 *   - q:            [num_seqs, num_heads, head_size], real floating dtype
 *   - k_cache:      [num_blocks, num_kv_heads, head_size / x, block_size, x],
 *                   same dtype as q
 *   - v_cache:      [num_blocks, num_kv_heads, head_size, block_size],
 *                   same dtype as q
 *   - block_tables: [num_seqs, max_num_blocks_per_seq], uint32
 *   - context_lens: [num_seqs], uint32
 *
 * head_size must be one of {64, 80, 96, 112, 128, 192, 256}.
 *
 * @param max_context_len Used to derive the partition count (512 tokens per
 *                        partition) and thereby the v1/v2 kernel choice.
 * @param softmax_scale   Stored on the primitive and forwarded to the kernel.
 * @param alibi_slopes    Optional; appended as a sixth primitive input.
 *                        NOTE(review): its dtype/shape are not validated here.
 * @param softcapping     Optional; stored on the primitive.
 * @param s_              Stream or device to run on.
 * @return An array with the shape and dtype of q.
 * @throws std::invalid_argument if any dtype, rank or shape check fails.
 *
 * NOTE(review): the default arguments are also spelled out on the declaration
 * in mlx/paged_attention.h; they must be removed from this definition if that
 * header is ever included by this translation unit.
 */
array paged_attention(
    const array& q,
    const array& k_cache,
    const array& v_cache,
    const array& block_tables,
    const array& context_lens,
    int max_context_len,
    float softmax_scale,
    std::optional<array> alibi_slopes = std::nullopt,
    std::optional<float> softcapping = std::nullopt,
    StreamOrDevice s_ = {}) {
  auto s = to_stream(s_);

  // Supported dtypes.
  if (!issubdtype(q.dtype(), floating)) {
    throw std::invalid_argument(
        "[paged_attention] Only real floating types are supported");
  }
  if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) {
    throw std::invalid_argument(
        "[paged_attention] q/k_cache/v_cache dtype must match");
  }
  if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) {
    throw std::invalid_argument(
        "[paged_attention] block_tables/context_lens dtype must be uint32");
  }

  // Rank checks.
  if (q.ndim() != 3) {
    throw std::invalid_argument("[paged_attention] `q` must be rank-3");
  }
  if (k_cache.ndim() != 5) {
    throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5");
  }
  if (v_cache.ndim() != 4) {
    throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4");
  }
  if (block_tables.ndim() != 2) {
    throw std::invalid_argument(
        "[paged_attention] `block_tables` must be rank-2");
  }
  if (context_lens.ndim() != 1) {
    throw std::invalid_argument(
        "[paged_attention] `context_lens` must be rank-1");
  }

  // Shape consistency.
  const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size]
  const auto& kc_shape = k_cache.shape();
  const auto& vc_shape = v_cache.shape();
  const auto& bt_shape = block_tables.shape();
  const auto& cl_shape = context_lens.shape();

  int num_seqs = q_shape[0];
  int num_heads = q_shape[1];
  int head_size = q_shape[2];

  // Allowed head sizes.
  switch (head_size) {
    case 64:
    case 80:
    case 96:
    case 112:
    case 128:
    case 192:
    case 256:
      break;
    default:
      throw std::invalid_argument(
          "[paged_attention] `head_size` must be one of "
          "{64, 80, 96, 112, 128, 192, 256}");
  }

  int max_num_blocks_per_seq = bt_shape[1];

  // block_tables first dimension must match num_seqs.
  if (bt_shape[0] != num_seqs) {
    std::stringstream ss;
    ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0]
       << ") must equal q.shape[0] (" << num_seqs << ")";
    throw std::invalid_argument(ss.str());
  }

  // Extract k_cache dimensions.
  int num_blocks = kc_shape[0];
  int num_kv_heads = kc_shape[1];
  int head_size_kc = kc_shape[2];
  int block_size = kc_shape[3];
  int x = kc_shape[4];

  if (head_size_kc * x != head_size) {
    std::stringstream ss;
    ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x
       << ") must equal q head_size (" << head_size << ")";
    throw std::invalid_argument(ss.str());
  }

  // v_cache must match the derived dimensions.
  if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads &&
        vc_shape[2] == head_size && vc_shape[3] == block_size)) {
    throw std::invalid_argument(
        "[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`");
  }

  // context_lens length must match num_seqs.
  if (cl_shape[0] != num_seqs) {
    std::stringstream ss;
    ss << "paged_attention: context_lens length (" << cl_shape[0]
       << ") must equal q.shape[0] (" << num_seqs << ")";
    throw std::invalid_argument(ss.str());
  }

  // block_size is used as a divisor just below; an empty cache
  // (shape[3] == 0) would otherwise trigger a division by zero.
  if (block_size <= 0) {
    throw std::invalid_argument(
        "[paged_attention] `k_cache` block size (shape[3]) must be positive");
  }

  // v1 is chosen when the whole context fits in a single partition or when
  // num_seqs * num_heads exceeds 512, provided the partition size is a
  // multiple of the block size; otherwise v2 is used.
  constexpr int partition_size = 512;
  int max_num_partitions =
      (max_context_len + partition_size - 1) / partition_size; // ceil-div
  bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
      (partition_size % block_size == 0);

  auto out_shape = q.shape();

  // Read the strides before assembling the input list so no operand is
  // touched after it has been handed over.
  int q_stride = static_cast<int>(q.strides()[0]);
  int kv_block_stride = static_cast<int>(k_cache.strides()[0]);
  int kv_head_stride = static_cast<int>(k_cache.strides()[1]);

  // The operands are const references, so they are copied (std::move on them
  // would silently degrade to a copy anyway).
  std::vector<array> inputs{q, k_cache, v_cache, block_tables, context_lens};
  if (alibi_slopes.has_value()) {
    inputs.push_back(std::move(*alibi_slopes));
  }

  // The arguments must follow the PagedAttention constructor order exactly:
  // softmax_scale comes *after* num_seqs. All of these are arithmetic types,
  // so a misplaced argument converts silently instead of failing to compile.
  auto primitive = std::make_shared<PagedAttention>(
      s,
      use_v1,
      max_context_len,
      head_size,
      block_size,
      num_kv_heads,
      max_num_blocks_per_seq,
      max_num_partitions,
      q_stride,
      kv_block_stride,
      kv_head_stride,
      num_heads,
      num_seqs,
      softmax_scale,
      softcapping);

  return array(
      std::move(out_shape), q.dtype(), std::move(primitive), std::move(inputs));
}
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
34
mlx/paged_attention.h
Normal file
34
mlx/paged_attention.h
Normal file
@ -0,0 +1,34 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
/**
|
||||
* \defgroup ops Paged attention operations
|
||||
* @{
|
||||
*/
|
||||
|
||||
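/**
 * Paged attention over a block-structured key/value cache.
 *
 * The implementation validates its operands and throws
 * std::invalid_argument on any mismatch:
 *   - q:            rank-3 [num_seqs, num_heads, head_size], real floating
 *                   dtype; head_size must be one of
 *                   {64, 80, 96, 112, 128, 192, 256}
 *   - k_cache:      rank-5 [num_blocks, num_kv_heads, head_size / x,
 *                   block_size, x], same dtype as q
 *   - v_cache:      rank-4 [num_blocks, num_kv_heads, head_size, block_size],
 *                   same dtype as q
 *   - block_tables: rank-2 [num_seqs, max_num_blocks_per_seq], uint32
 *   - context_lens: rank-1 [num_seqs], uint32
 *
 * Returns an array with the same shape and dtype as q.
 */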
|
||||
array paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
int max_context_len,
|
||||
float softmax_scale,
|
||||
std::optional<array> alibi_slopes = std::nullopt,
|
||||
std::optional<float> softcapping = std::nullopt,
|
||||
StreamOrDevice s_ = {});
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
82
mlx/paged_attention_primitives.h
Normal file
82
mlx/paged_attention_primitives.h
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// Required for using M_PI in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
class PagedAttention : public UnaryPrimitive {
|
||||
public:
|
||||
explicit PagedAttention(
|
||||
Stream stream,
|
||||
bool use_v1,
|
||||
int max_context_len,
|
||||
int head_size,
|
||||
int block_size,
|
||||
int num_kv_heads,
|
||||
int max_num_blocks_per_seq,
|
||||
int max_num_partitions,
|
||||
int q_stride,
|
||||
int kv_block_stride,
|
||||
int kv_head_stride,
|
||||
int num_heads,
|
||||
int num_seqs,
|
||||
float softmax_scale,
|
||||
std::optional<float> softcapping = std::nullopt)
|
||||
: UnaryPrimitive(stream),
|
||||
use_v1_(use_v1),
|
||||
max_context_len_(max_context_len),
|
||||
head_size_(head_size),
|
||||
block_size_(block_size),
|
||||
num_kv_heads_(num_kv_heads),
|
||||
max_num_blocks_per_seq_(max_num_blocks_per_seq),
|
||||
max_num_partitions_(max_num_partitions),
|
||||
q_stride_(q_stride),
|
||||
kv_block_stride_(kv_block_stride),
|
||||
kv_head_stride_(kv_head_stride),
|
||||
num_heads_(num_heads),
|
||||
num_seqs_(num_seqs),
|
||||
softmax_scale_(softmax_scale),
|
||||
softcapping_(softcapping) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
|
||||
|
||||
DEFINE_PRINT(PagedAttention);
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
max_context_len_,
|
||||
head_size_,
|
||||
block_size_,
|
||||
softmax_scale_,
|
||||
softcapping_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_v1_;
|
||||
int max_context_len_;
|
||||
int head_size_;
|
||||
int block_size_;
|
||||
int num_kv_heads_;
|
||||
int max_num_blocks_per_seq_;
|
||||
int max_num_partitions_;
|
||||
int q_stride_;
|
||||
int kv_block_stride_;
|
||||
int kv_head_stride_;
|
||||
int num_heads_;
|
||||
int num_seqs_;
|
||||
float softmax_scale_;
|
||||
std::optional<float> softcapping_ = std::nullopt;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
@ -1,10 +1,8 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestAutograd.test_update_state",
|
||||
"TestBF16.test_arg_reduction_ops",
|
||||
"TestBF16.test_reduction_ops",
|
||||
"TestBlas.test_complex_gemm",
|
||||
"TestCompile.test_compile_dynamic_dims",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
"TestLoad.test_load_f8_e4m3",
|
||||
@ -14,24 +12,14 @@ cuda_skip = {
|
||||
"TestLayers.test_quantized_embedding",
|
||||
"TestLayers.test_sin_pe",
|
||||
"TestLayers.test_upsample",
|
||||
"TestOps.test_array_equal",
|
||||
"TestOps.test_complex_ops",
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestOps.test_softmax",
|
||||
"TestOps.test_sort",
|
||||
"TestOps.test_tile",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
# DivMod NYI
|
||||
"TestOps.test_divmod",
|
||||
"TestEval.test_multi_output_eval_during_transform",
|
||||
# Partition NYI
|
||||
"TestAutograd.test_topk_grad",
|
||||
"TestOps.test_argpartition",
|
||||
"TestOps.test_partition",
|
||||
# Block masked matmul NYI
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
# Gather matmul NYI
|
||||
|
Loading…
Reference in New Issue
Block a user