mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Sdpa fix (#1558)
This commit is contained in:
committed by
GitHub
parent
09bc32f62f
commit
62f297b51d
@@ -162,7 +162,8 @@ void sdpa_vector(
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
int N = k.shape(2);
|
||||
int B = q.shape(0) * q.shape(1);
|
||||
- size_t stride = k.strides()[1];
|
||||
+ size_t k_stride = k.strides()[1];
|
||||
+ size_t v_stride = v.strides()[1];
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
MTL::Size grid_dims(1, B, 1);
|
||||
|
||||
@@ -178,8 +179,9 @@ void sdpa_vector(
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 5);
|
||||
- compute_encoder->setBytes(&stride, sizeof(size_t), 6);
|
||||
- compute_encoder->setBytes(&scale, sizeof(float), 7);
|
||||
+ compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
|
||||
+ compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
|
||||
+ compute_encoder->setBytes(&scale, sizeof(float), 8);
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
Reference in New Issue
Block a user