mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Add pagedattn primitive
This commit is contained in:
parent
db5a7c6192
commit
190c72739b
@ -12,6 +12,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
|
@ -102,6 +102,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
|
31
mlx/backend/metal/paged_attention.cpp
Normal file
31
mlx/backend/metal/paged_attention.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/reduce.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/paged_attention_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& q = inputs[0];
|
||||
auto& k_cache = inputs[1];
|
||||
auto& v_cache = inputs[2];
|
||||
auto& block_tables = inputs[3];
|
||||
auto& context_lens = inputs[4];
|
||||
const auto alibi_slopes =
|
||||
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
|
||||
return;
|
||||
}
|
||||
} // namespace mlx::core::paged_attention
|
@ -16,6 +16,7 @@
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/paged_attention.h"
|
||||
#include "mlx/random.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/transforms.h"
|
||||
|
150
mlx/paged_attention.cpp
Normal file
150
mlx/paged_attention.cpp
Normal file
@ -0,0 +1,150 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// Required for using M_PI in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/paged_attention_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
array paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
int max_context_len,
|
||||
float softmax_scale,
|
||||
std::optional<array> alibi_slopes = std::nullopt,
|
||||
std::optional<float> softcapping = std::nullopt,
|
||||
StreamOrDevice s_ = {}) {
|
||||
auto s = to_stream(s_);
|
||||
|
||||
// supported dtypes
|
||||
if (!issubdtype(q.dtype(), floating)) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] Only real floating types are supported");
|
||||
}
|
||||
if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] q/k_cache/v_cache dtype must match");
|
||||
}
|
||||
if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] block_tables/context_lens dtype must be uint32");
|
||||
}
|
||||
|
||||
// rank checks
|
||||
if (q.ndim() != 3)
|
||||
throw std::invalid_argument("[paged_attention] `q` must be rank-3");
|
||||
if (k_cache.ndim() != 5)
|
||||
throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5");
|
||||
if (v_cache.ndim() != 4)
|
||||
throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4");
|
||||
if (block_tables.ndim() != 2)
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `block_tables` must be rank-2");
|
||||
if (context_lens.ndim() != 1)
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `context_lens` must be rank-1");
|
||||
|
||||
// 4. Shape consistency
|
||||
const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size]
|
||||
const auto& kc_shape = k_cache.shape();
|
||||
const auto& vc_shape = v_cache.shape();
|
||||
const auto& bt_shape = block_tables.shape();
|
||||
const auto& cl_shape = context_lens.shape();
|
||||
|
||||
int64_t num_seqs = q_shape[0];
|
||||
int64_t num_heads = q_shape[1];
|
||||
int64_t head_size = q_shape[2];
|
||||
|
||||
// Allowed head sizes
|
||||
switch (head_size) {
|
||||
case 64:
|
||||
case 80:
|
||||
case 96:
|
||||
case 112:
|
||||
case 128:
|
||||
case 192:
|
||||
case 256:
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `head_size` must be one of "
|
||||
"{64, 80, 96, 112, 128, 192, 256}");
|
||||
}
|
||||
|
||||
// block_tables first dimension must match num_seqs
|
||||
if (bt_shape[0] != num_seqs) {
|
||||
std::stringstream ss;
|
||||
ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0]
|
||||
<< ") must equal q.shape[0] (" << num_seqs << ")";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
// Extract k_cache dimensions
|
||||
int64_t num_blocks = kc_shape[0];
|
||||
int64_t num_kv_heads = kc_shape[1];
|
||||
int64_t head_size_kc = kc_shape[2];
|
||||
int64_t block_size = kc_shape[3];
|
||||
int64_t x = kc_shape[4];
|
||||
|
||||
if (head_size_kc * x != head_size) {
|
||||
std::stringstream ss;
|
||||
ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x
|
||||
<< ") must equal q head_size (" << head_size << ")";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
// v_cache must match the derived dimensions
|
||||
if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads &&
|
||||
vc_shape[2] == head_size && vc_shape[3] == block_size)) {
|
||||
throw std::invalid_argument(
|
||||
"[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`");
|
||||
}
|
||||
|
||||
// context_lens length must match num_seqs
|
||||
if (cl_shape[0] != num_seqs) {
|
||||
std::stringstream ss;
|
||||
ss << "paged_attention: context_lens length (" << cl_shape[0]
|
||||
<< ") must equal q.shape[0] (" << num_seqs << ")";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
constexpr int64_t partition_size = 512;
|
||||
int64_t max_num_partitions =
|
||||
(max_context_len + partition_size - 1) / partition_size; // ceil‑div
|
||||
bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
|
||||
(partition_size % block_size == 0);
|
||||
|
||||
auto out_shape = q.shape();
|
||||
|
||||
auto inputs = std::vector{
|
||||
std::move(q),
|
||||
std::move(k_cache),
|
||||
std::move(v_cache),
|
||||
std::move(block_tables),
|
||||
std::move(context_lens)};
|
||||
if (alibi_slopes.has_value()) {
|
||||
inputs.push_back(std::move(alibi_slopes.value()));
|
||||
}
|
||||
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
q.dtype(),
|
||||
std::make_shared<PagedAttention>(
|
||||
to_stream(s), use_v1, max_context_len, softmax_scale, softcapping),
|
||||
inputs);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
34
mlx/paged_attention.h
Normal file
34
mlx/paged_attention.h
Normal file
@ -0,0 +1,34 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
/**
|
||||
* \defgroup ops Paged attention operations
|
||||
* @{
|
||||
*/
|
||||
|
||||
/** PagedAttention operation. */
|
||||
array paged_attention(
|
||||
const array& q,
|
||||
const array& k_cache,
|
||||
const array& v_cache,
|
||||
const array& block_tables,
|
||||
const array& context_lens,
|
||||
int max_context_len,
|
||||
float softmax_scale,
|
||||
std::optional<array> alibi_slopes = std::nullopt,
|
||||
std::optional<float> softcapping = std::nullopt,
|
||||
StreamOrDevice s_ = {});
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
47
mlx/paged_attention_primitives.h
Normal file
47
mlx/paged_attention_primitives.h
Normal file
@ -0,0 +1,47 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// Required for using M_PI in MSVC.
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::paged_attention {
|
||||
|
||||
class PagedAttention : public UnaryPrimitive {
|
||||
public:
|
||||
explicit PagedAttention(
|
||||
Stream stream,
|
||||
bool use_v1,
|
||||
int max_context_len,
|
||||
float softmax_scale,
|
||||
std::optional<float> softcapping = std::nullopt)
|
||||
: UnaryPrimitive(stream),
|
||||
use_v1_(use_v1),
|
||||
max_context_len_(max_context_len),
|
||||
softmax_scale_(softmax_scale),
|
||||
softcapping_(softcapping) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
|
||||
|
||||
DEFINE_PRINT(PagedAttention);
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return std::make_tuple(max_context_len_, softmax_scale_, softcapping_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_v1_;
|
||||
int max_context_len_;
|
||||
float softmax_scale_;
|
||||
std::optional<float> softcapping_ = std::nullopt;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::paged_attention
|
Loading…
Reference in New Issue
Block a user