mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Sdpa fix (#1558)
This commit is contained in:
parent
09bc32f62f
commit
62f297b51d
@ -936,6 +936,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
|
|||||||
const constant int& gqa_factor, \
|
const constant int& gqa_factor, \
|
||||||
const constant int& N, \
|
const constant int& N, \
|
||||||
const constant size_t& k_stride, \
|
const constant size_t& k_stride, \
|
||||||
|
const constant size_t& v_stride, \
|
||||||
const constant float& scale, \
|
const constant float& scale, \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
|
@ -13,6 +13,7 @@ template <typename T, int D>
|
|||||||
const constant int& gqa_factor,
|
const constant int& gqa_factor,
|
||||||
const constant int& N,
|
const constant int& N,
|
||||||
const constant size_t& k_stride,
|
const constant size_t& k_stride,
|
||||||
|
const constant size_t& v_stride,
|
||||||
const constant float& scale,
|
const constant float& scale,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
@ -38,7 +39,7 @@ template <typename T, int D>
|
|||||||
const int kv_head_idx = head_idx / gqa_factor;
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
queries += head_idx * D + simd_lid * elem_per_thread;
|
queries += head_idx * D + simd_lid * elem_per_thread;
|
||||||
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||||
values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||||
out += head_idx * D + simd_gid * elem_per_thread;
|
out += head_idx * D + simd_gid * elem_per_thread;
|
||||||
|
|
||||||
// Read the query and 0 the output accumulator
|
// Read the query and 0 the output accumulator
|
||||||
|
@ -162,7 +162,8 @@ void sdpa_vector(
|
|||||||
int gqa_factor = q.shape(1) / k.shape(1);
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
int N = k.shape(2);
|
int N = k.shape(2);
|
||||||
int B = q.shape(0) * q.shape(1);
|
int B = q.shape(0) * q.shape(1);
|
||||||
size_t stride = k.strides()[1];
|
size_t k_stride = k.strides()[1];
|
||||||
|
size_t v_stride = v.strides()[1];
|
||||||
MTL::Size group_dims(1024, 1, 1);
|
MTL::Size group_dims(1024, 1, 1);
|
||||||
MTL::Size grid_dims(1, B, 1);
|
MTL::Size grid_dims(1, B, 1);
|
||||||
|
|
||||||
@ -178,8 +179,9 @@ void sdpa_vector(
|
|||||||
compute_encoder.set_output_array(out, 3);
|
compute_encoder.set_output_array(out, 3);
|
||||||
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
|
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
|
||||||
compute_encoder->setBytes(&N, sizeof(int), 5);
|
compute_encoder->setBytes(&N, sizeof(int), 5);
|
||||||
compute_encoder->setBytes(&stride, sizeof(size_t), 6);
|
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
|
||||||
compute_encoder->setBytes(&scale, sizeof(float), 7);
|
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
|
||||||
|
compute_encoder->setBytes(&scale, sizeof(float), 8);
|
||||||
|
|
||||||
// Launch
|
// Launch
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
@ -167,6 +167,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
||||||
|
|
||||||
|
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
||||||
|
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
||||||
|
v = mx.random.normal(shape=(1, 32, 128, Dk))
|
||||||
|
|
||||||
|
atol = 1e-6
|
||||||
|
y = mlx_primitives_sdpa(q, k, v[:, :, :32], scale)
|
||||||
|
y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
|
||||||
|
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(failfast=True)
|
unittest.main(failfast=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user