mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-01 07:56:44 +08:00
add checks
This commit is contained in:
parent
3507c104a5
commit
c89ddf62b4
@ -22,18 +22,14 @@ instantiate_sdpa_vector_heads(bfloat16_t)
|
|||||||
instantiate_sdpa_vector_heads(float16_t)
|
instantiate_sdpa_vector_heads(float16_t)
|
||||||
|
|
||||||
// Quantized SDPA vector instantiations
|
// Quantized SDPA vector instantiations
|
||||||
#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \
|
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
|
||||||
instantiate_kernel( \
|
instantiate_kernel( \
|
||||||
#name "_" #type "_" #head_dim "_" #group_size "_" #bits, \
|
"quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \
|
||||||
name, type, head_dim, group_size, bits)
|
quant_sdpa_vector_2pass_1, type, head_dim, group_size, bits)
|
||||||
|
|
||||||
#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \
|
|
||||||
instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \
|
|
||||||
instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits)
|
|
||||||
|
|
||||||
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
|
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
|
||||||
instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4) \
|
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
|
||||||
instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8)
|
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
|
||||||
|
|
||||||
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
|
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
|
||||||
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
|
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
|
||||||
@ -42,7 +38,6 @@ instantiate_sdpa_vector_heads(float16_t)
|
|||||||
|
|
||||||
#define instantiate_quant_sdpa_vector_heads(type) \
|
#define instantiate_quant_sdpa_vector_heads(type) \
|
||||||
instantiate_quant_sdpa_vector_group_size(type, 64) \
|
instantiate_quant_sdpa_vector_group_size(type, 64) \
|
||||||
instantiate_quant_sdpa_vector_group_size(type, 96) \
|
|
||||||
instantiate_quant_sdpa_vector_group_size(type, 128)
|
instantiate_quant_sdpa_vector_group_size(type, 128)
|
||||||
|
|
||||||
instantiate_quant_sdpa_vector_heads(float)
|
instantiate_quant_sdpa_vector_heads(float)
|
||||||
|
@ -174,147 +174,6 @@ METAL_FUNC void load_values(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int D, int group_size, int bits>
|
|
||||||
[[kernel]] void quant_sdpa_vector(
|
|
||||||
const device T* queries [[buffer(0)]],
|
|
||||||
const device uint32_t* keys [[buffer(1)]],
|
|
||||||
const device T* key_scales [[buffer(2)]],
|
|
||||||
const device T* key_biases [[buffer(3)]],
|
|
||||||
const device uint32_t* values [[buffer(4)]],
|
|
||||||
const device T* value_scales [[buffer(5)]],
|
|
||||||
const device T* value_biases [[buffer(6)]],
|
|
||||||
device T* out [[buffer(7)]],
|
|
||||||
const constant int& gqa_factor,
|
|
||||||
const constant int& N,
|
|
||||||
const constant size_t& k_stride,
|
|
||||||
const constant size_t& group_stride,
|
|
||||||
const constant float& scale,
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]],
|
|
||||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
|
||||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
|
||||||
constexpr int BN = 32;
|
|
||||||
constexpr int BD = 4;
|
|
||||||
constexpr int elem_per_thread = D / BD;
|
|
||||||
constexpr int pack_factor = 32 / bits;
|
|
||||||
|
|
||||||
const int stride = BN * D;
|
|
||||||
|
|
||||||
typedef float U;
|
|
||||||
|
|
||||||
thread U q[elem_per_thread];
|
|
||||||
thread U k[elem_per_thread];
|
|
||||||
thread U v[elem_per_thread];
|
|
||||||
thread U o[elem_per_thread];
|
|
||||||
|
|
||||||
threadgroup U outputs[BN * BD];
|
|
||||||
threadgroup U max_scores[BN];
|
|
||||||
threadgroup U sum_exp_scores[BN];
|
|
||||||
|
|
||||||
// Adjust positions
|
|
||||||
const int head_idx = tid.y;
|
|
||||||
const int kv_head_idx = head_idx / gqa_factor;
|
|
||||||
queries += head_idx * D + quad_lid * elem_per_thread;
|
|
||||||
|
|
||||||
const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
|
|
||||||
const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
|
|
||||||
const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
|
|
||||||
keys += packed_idx;
|
|
||||||
key_scales += group_idx;
|
|
||||||
key_biases += group_idx;
|
|
||||||
values += packed_idx;
|
|
||||||
value_scales += group_idx;
|
|
||||||
value_biases += group_idx;
|
|
||||||
|
|
||||||
out += head_idx * D + simd_gid * elem_per_thread;
|
|
||||||
|
|
||||||
// Read the query and 0 the output accumulator
|
|
||||||
U query_sum = load_queries<T, U, elem_per_thread, bits>(
|
|
||||||
queries, q, static_cast<U>(scale));
|
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
|
||||||
o[i] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
U max_score = -INFINITY;
|
|
||||||
U sum_exp_score = 0;
|
|
||||||
|
|
||||||
// For each key
|
|
||||||
for (int i = quad_gid; i < N; i += BN) {
|
|
||||||
load_keys<U, elem_per_thread, bits>(keys, k);
|
|
||||||
|
|
||||||
// Assume D % group_size == 0 so all the keys are in the same group
|
|
||||||
U key_scale = key_scales[0];
|
|
||||||
U key_bias = key_biases[0];
|
|
||||||
|
|
||||||
// Compute the i-th score
|
|
||||||
U score = 0;
|
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
|
||||||
score += q[i] * k[i];
|
|
||||||
}
|
|
||||||
score = score * key_scale + query_sum * key_bias;
|
|
||||||
score = quad_sum(score);
|
|
||||||
|
|
||||||
// Update the accumulators
|
|
||||||
U new_max = max(max_score, score);
|
|
||||||
U factor = fast::exp(max_score - new_max);
|
|
||||||
U exp_score = fast::exp(score - new_max);
|
|
||||||
|
|
||||||
max_score = new_max;
|
|
||||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
|
||||||
|
|
||||||
U value_scale = value_scales[0];
|
|
||||||
U value_bias = value_biases[0];
|
|
||||||
|
|
||||||
// Load the values
|
|
||||||
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
|
|
||||||
|
|
||||||
// Update the output accumulator
|
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
|
||||||
o[i] = o[i] * factor + exp_score * v[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move the pointers to the next kv
|
|
||||||
keys += stride / pack_factor;
|
|
||||||
key_scales += stride / group_size;
|
|
||||||
key_biases += stride / group_size;
|
|
||||||
values += stride / pack_factor;
|
|
||||||
value_scales += stride / group_size;
|
|
||||||
value_biases += stride / group_size;
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Each thread has a partial part of the output so we need to combine them.
|
|
||||||
|
|
||||||
// First let's communicate the max and sum_exp
|
|
||||||
// Each quadgroup communicates it's max score
|
|
||||||
if (quad_lid == 0) {
|
|
||||||
max_scores[quad_gid] = max_score;
|
|
||||||
sum_exp_scores[quad_gid] = sum_exp_score;
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
max_score = max_scores[simd_lid];
|
|
||||||
U new_max = simd_max(max_score);
|
|
||||||
U factor = fast::exp(max_score - new_max);
|
|
||||||
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
|
|
||||||
|
|
||||||
// Now we need to aggregate all the outputs
|
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
|
||||||
// 128 threads with 32 values per thread
|
|
||||||
outputs[simd_gid * BN + simd_lid] = o[i];
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
|
|
||||||
// And write the output
|
|
||||||
if (simd_lid == 0) {
|
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
|
||||||
out[i] = static_cast<T>(o[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, int D>
|
template <typename T, int D>
|
||||||
[[kernel]] void sdpa_vector_2pass_1(
|
[[kernel]] void sdpa_vector_2pass_1(
|
||||||
const device T* queries [[buffer(0)]],
|
const device T* queries [[buffer(0)]],
|
||||||
|
@ -242,65 +242,6 @@ void sdpa_vector_2pass(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
void quant_sdpa_vector(
|
|
||||||
const Stream& s,
|
|
||||||
metal::Device& d,
|
|
||||||
const array& q,
|
|
||||||
const array& k,
|
|
||||||
const array& k_scales,
|
|
||||||
const array& k_biases,
|
|
||||||
const array& v,
|
|
||||||
const array& v_scales,
|
|
||||||
const array& v_biases,
|
|
||||||
array& out,
|
|
||||||
float scale,
|
|
||||||
int group_size,
|
|
||||||
int bits) {
|
|
||||||
// Set the kernel name
|
|
||||||
std::string kname;
|
|
||||||
kname.reserve(96);
|
|
||||||
kname += "quant_sdpa_vector_";
|
|
||||||
kname += get_type_string(q.dtype());
|
|
||||||
kname += "_";
|
|
||||||
kname += std::to_string(q.shape(-1));
|
|
||||||
kname += "_";
|
|
||||||
kname += std::to_string(group_size);
|
|
||||||
kname += "_";
|
|
||||||
kname += std::to_string(bits);
|
|
||||||
|
|
||||||
// Compute the necessary sizes
|
|
||||||
int gqa_factor = q.shape(1) / k.shape(1);
|
|
||||||
int N = k.shape(2);
|
|
||||||
int B = q.shape(0) * q.shape(1);
|
|
||||||
size_t stride = k.strides()[1];
|
|
||||||
size_t group_stride = k_scales.strides()[1];
|
|
||||||
MTL::Size group_dims(128, 1, 1);
|
|
||||||
MTL::Size grid_dims(1, B, 1);
|
|
||||||
|
|
||||||
// Get the kernel
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(kname);
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
// Set its arguments
|
|
||||||
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
|
||||||
compute_encoder.set_input_array(k, 1);
|
|
||||||
compute_encoder.set_input_array(k_scales, 2);
|
|
||||||
compute_encoder.set_input_array(k_biases, 3);
|
|
||||||
compute_encoder.set_input_array(v, 4);
|
|
||||||
compute_encoder.set_input_array(v_scales, 5);
|
|
||||||
compute_encoder.set_input_array(v_biases, 6);
|
|
||||||
compute_encoder.set_output_array(out, 7);
|
|
||||||
compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8);
|
|
||||||
compute_encoder.set_bytes(&N, sizeof(int), 9);
|
|
||||||
compute_encoder.set_bytes(&stride, sizeof(size_t), 10);
|
|
||||||
compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11);
|
|
||||||
compute_encoder.set_bytes(&scale, sizeof(float), 12);
|
|
||||||
|
|
||||||
// Launch
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void quant_sdpa_vector_2pass(
|
void quant_sdpa_vector_2pass(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
|
61
mlx/fast.cpp
61
mlx/fast.cpp
@ -516,15 +516,11 @@ bool RoPE::is_equivalent(const Primitive& other) const {
|
|||||||
offset_ == a_other.offset_ && forward_ == a_other.forward_);
|
offset_ == a_other.offset_ && forward_ == a_other.forward_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
void check_sdpa_arguments(
|
||||||
array scaled_dot_product_attention(
|
|
||||||
const array& queries,
|
const array& queries,
|
||||||
const array& keys,
|
const array& keys,
|
||||||
const array& values,
|
const array& values,
|
||||||
const float scale,
|
const std::optional<array>& mask) {
|
||||||
const std::optional<array>& mask,
|
|
||||||
const std::optional<int> memory_efficient_threshold,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
for (const auto& tensor : {queries, keys, values}) {
|
for (const auto& tensor : {queries, keys, values}) {
|
||||||
if (tensor.ndim() != 4) {
|
if (tensor.ndim() != 4) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -550,14 +546,6 @@ array scaled_dot_product_attention(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Q, K must have matching last dims (d_k aka 'head_dim');
|
|
||||||
if (queries.shape(-1) != keys.shape(-1)) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
|
|
||||||
<< queries.shape() << " for keys shape " << keys.shape() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// K, V must have matching number of heads (n_kv_heads);
|
// K, V must have matching number of heads (n_kv_heads);
|
||||||
auto n_q_heads = queries.shape(-3);
|
auto n_q_heads = queries.shape(-3);
|
||||||
auto n_kv_heads = keys.shape(-3);
|
auto n_kv_heads = keys.shape(-3);
|
||||||
@ -577,6 +565,26 @@ array scaled_dot_product_attention(
|
|||||||
<< n_q_heads << " for n_kv_heads " << n_kv_heads << ".";
|
<< n_q_heads << " for n_kv_heads " << n_kv_heads << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||||
|
array scaled_dot_product_attention(
|
||||||
|
const array& queries,
|
||||||
|
const array& keys,
|
||||||
|
const array& values,
|
||||||
|
const float scale,
|
||||||
|
const std::optional<array>& mask,
|
||||||
|
const std::optional<int> memory_efficient_threshold,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
check_sdpa_arguments(queries, keys, values, mask);
|
||||||
|
|
||||||
|
// Q, K must have matching last dims (d_k aka 'head_dim');
|
||||||
|
if (queries.shape(-1) != keys.shape(-1)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
|
||||||
|
<< queries.shape() << " for keys shape " << keys.shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
auto final_type = result_type(queries, keys, values);
|
auto final_type = result_type(queries, keys, values);
|
||||||
if (!issubdtype(final_type, floating)) {
|
if (!issubdtype(final_type, floating)) {
|
||||||
@ -590,6 +598,9 @@ array scaled_dot_product_attention(
|
|||||||
auto k = astype(keys, final_type, s);
|
auto k = astype(keys, final_type, s);
|
||||||
auto v = astype(values, final_type, s);
|
auto v = astype(values, final_type, s);
|
||||||
|
|
||||||
|
auto n_q_heads = queries.shape(-3);
|
||||||
|
auto n_kv_heads = keys.shape(-3);
|
||||||
|
|
||||||
/* generic implementation for use cases that Metal implementation does not
|
/* generic implementation for use cases that Metal implementation does not
|
||||||
* support. For non-supported cases listed below, use MLX primitives:
|
* support. For non-supported cases listed below, use MLX primitives:
|
||||||
* * CPU implementation
|
* * CPU implementation
|
||||||
@ -696,6 +707,25 @@ array quantized_scaled_dot_product_attention(
|
|||||||
const int bits,
|
const int bits,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
int el_per_int = 32 / bits;
|
int el_per_int = 32 / bits;
|
||||||
|
|
||||||
|
check_sdpa_arguments(queries, keys, values, mask);
|
||||||
|
|
||||||
|
// Q, K must have matching last dims (d_k aka 'head_dim');
|
||||||
|
if (queries.shape(-1) != keys.shape(-1) * el_per_int) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
|
||||||
|
<< queries.shape() << " for keys shape " << keys.shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto final_type = result_type(queries, key_scales, value_scales);
|
||||||
|
if (!issubdtype(final_type, floating)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] Received unsupported type "
|
||||||
|
<< final_type << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
int out_dim = values.shape(-1) * el_per_int;
|
int out_dim = values.shape(-1) * el_per_int;
|
||||||
|
|
||||||
auto n_q_heads = queries.shape(-3);
|
auto n_q_heads = queries.shape(-3);
|
||||||
@ -760,8 +790,9 @@ array quantized_scaled_dot_product_attention(
|
|||||||
return std::vector<array>{out};
|
return std::vector<array>{out};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
int query_head_dim = queries.shape(-1);
|
||||||
int L = queries.shape(2);
|
int L = queries.shape(2);
|
||||||
if (L > 1) {
|
if (L > 1 && query_head_dim != 64 && query_head_dim != 128) {
|
||||||
if (needs_mask) {
|
if (needs_mask) {
|
||||||
return fallback(
|
return fallback(
|
||||||
{queries,
|
{queries,
|
||||||
|
Loading…
Reference in New Issue
Block a user