mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fully wrap the command encoder (#1572)
* fully wrap the command encoder * use consistent style + fix extensions
This commit is contained in:
@@ -101,31 +101,31 @@ void launch_qmm(
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_input_array(x, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder->setBytes(&D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder.set_bytes(D, 5);
|
||||
compute_encoder.set_bytes(O, 6);
|
||||
|
||||
int offset = 7;
|
||||
if (matrix) {
|
||||
compute_encoder->setBytes(&B, sizeof(int), 7);
|
||||
compute_encoder.set_bytes(B, 7);
|
||||
offset += 1;
|
||||
}
|
||||
|
||||
if (batched || gather) {
|
||||
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
|
||||
set_vector_bytes(compute_encoder, x_shape, offset + 1);
|
||||
set_vector_bytes(compute_encoder, x_strides, offset + 2);
|
||||
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
|
||||
set_vector_bytes(compute_encoder, w_shape, offset + 4);
|
||||
set_vector_bytes(compute_encoder, w_strides, offset + 5);
|
||||
set_vector_bytes(compute_encoder, s_strides, offset + 6);
|
||||
set_vector_bytes(compute_encoder, b_strides, offset + 7);
|
||||
compute_encoder.set_bytes(x_batch_ndims, offset);
|
||||
compute_encoder.set_vector_bytes(x_shape, offset + 1);
|
||||
compute_encoder.set_vector_bytes(x_strides, offset + 2);
|
||||
compute_encoder.set_bytes(w_batch_ndims, offset + 3);
|
||||
compute_encoder.set_vector_bytes(w_shape, offset + 4);
|
||||
compute_encoder.set_vector_bytes(w_strides, offset + 5);
|
||||
compute_encoder.set_vector_bytes(s_strides, offset + 6);
|
||||
compute_encoder.set_vector_bytes(b_strides, offset + 7);
|
||||
}
|
||||
if (gather) {
|
||||
auto& lhs_indices = inputs[4];
|
||||
@@ -137,15 +137,15 @@ void launch_qmm(
|
||||
auto& lhs_strides = lhs_indices.strides();
|
||||
auto& rhs_strides = rhs_indices.strides();
|
||||
|
||||
compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
|
||||
set_vector_bytes(compute_encoder, batch_shape, offset + 9);
|
||||
compute_encoder.set_bytes(batch_ndims, offset + 8);
|
||||
compute_encoder.set_vector_bytes(batch_shape, offset + 9);
|
||||
compute_encoder.set_input_array(lhs_indices, offset + 10);
|
||||
compute_encoder.set_input_array(rhs_indices, offset + 11);
|
||||
set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
|
||||
set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
|
||||
compute_encoder.set_vector_bytes(lhs_strides, offset + 12);
|
||||
compute_encoder.set_vector_bytes(rhs_strides, offset + 13);
|
||||
}
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
@@ -236,27 +236,27 @@ void qvm_split_k(
|
||||
// Encode and dispatch kernel
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_input_array(x, 3);
|
||||
compute_encoder.set_output_array(intermediate, 4);
|
||||
compute_encoder->setBytes(&split_D, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&O, sizeof(int), 6);
|
||||
compute_encoder.set_bytes(split_D, 5);
|
||||
compute_encoder.set_bytes(O, 6);
|
||||
|
||||
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
|
||||
set_vector_bytes(compute_encoder, x_shape, 8);
|
||||
set_vector_bytes(compute_encoder, x_strides, 9);
|
||||
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
|
||||
set_vector_bytes(compute_encoder, w_shape, 11);
|
||||
set_vector_bytes(compute_encoder, w_strides, 12);
|
||||
set_vector_bytes(compute_encoder, s_strides, 13);
|
||||
set_vector_bytes(compute_encoder, b_strides, 14);
|
||||
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
|
||||
compute_encoder.set_bytes(x_batch_ndims, 7);
|
||||
compute_encoder.set_vector_bytes(x_shape, 8);
|
||||
compute_encoder.set_vector_bytes(x_strides, 9);
|
||||
compute_encoder.set_bytes(w_batch_ndims, 10);
|
||||
compute_encoder.set_vector_bytes(w_shape, 11);
|
||||
compute_encoder.set_vector_bytes(w_strides, 12);
|
||||
compute_encoder.set_vector_bytes(s_strides, 13);
|
||||
compute_encoder.set_vector_bytes(b_strides, 14);
|
||||
compute_encoder.set_bytes(final_block_size, 15);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
int axis = intermediate.ndim() - 3;
|
||||
@@ -447,7 +447,7 @@ void fast::AffineQuantize::eval_gpu(
|
||||
auto template_def = get_template_definition(
|
||||
kname.str(), kernel_func, type_string, group_size_, bits_);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Treat uint32 as uint8 in kernel
|
||||
constexpr int uint8_per_uint32 = 4;
|
||||
@@ -471,7 +471,7 @@ void fast::AffineQuantize::eval_gpu(
|
||||
}
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user