mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fully wrap the command encoder (#1572)
* fully wrap the command encoder * use consistent style + fix extensions
This commit is contained in:
@@ -59,7 +59,7 @@ void sdpa_full_self_attention_metal(
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname_self_attention.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
uint hidden_dim = q.shape(-1);
|
||||
uint qseq = q.shape(-2);
|
||||
@@ -129,17 +129,14 @@ void sdpa_full_self_attention_metal(
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(&params, sizeof(MLXFastAttentionParams), 4);
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
|
||||
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void sdpa_vector(
|
||||
@@ -170,21 +167,21 @@ void sdpa_vector(
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&N, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 8);
|
||||
compute_encoder.set_bytes(gqa_factor, 4);
|
||||
compute_encoder.set_bytes(N, 5);
|
||||
compute_encoder.set_bytes(k_stride, 6);
|
||||
compute_encoder.set_bytes(v_stride, 7);
|
||||
compute_encoder.set_bytes(scale, 8);
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user