Fully wrap the command encoder (#1572)
* fully wrap the command encoder
* use consistent style + fix extensions
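The pattern in every hunk below is the same: calls that reached through the wrapper to the raw MTL::ComputeCommandEncoder (compute_encoder->setBytes(...), compute_encoder->setComputePipelineState(...), compute_encoder.dispatchThreadgroups(...)) become snake_case members of the CommandEncoder wrapper, and the free helper set_vector_bytes(compute_encoder, vec, idx) becomes compute_encoder.set_vector_bytes(vec, idx). A minimal sketch of the wrapper surface these call sites imply; the member name enc_ and the exact signatures are assumptions for illustration, not the actual declarations in mlx/backend/metal/device.h:

// Sketch only, inferred from the call sites in this diff; assumes
// metal-cpp ("Metal/Metal.hpp"). `enc_` and these signatures are
// illustrative, not the real MLX declarations.
#include <vector>

struct CommandEncoder {
  // setBytes(&v, sizeof(T), idx) collapses to set_bytes(v, idx): the
  // template deduces the byte count, so it cannot disagree with T.
  template <typename T>
  void set_bytes(const T& v, int idx) {
    enc_->setBytes(&v, sizeof(T), idx);
  }

  // Member replacement for the free set_vector_bytes(encoder, vec, idx):
  // sends vec.size() * sizeof(T) bytes starting at vec.data().
  template <typename T>
  void set_vector_bytes(const std::vector<T>& vec, int idx) {
    enc_->setBytes(vec.data(), vec.size() * sizeof(T), idx);
  }

  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
    enc_->setComputePipelineState(kernel);
  }

  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreadgroups(grid_dims, group_dims);
  }

  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreads(grid_dims, group_dims);
  }

  MTL::ComputeCommandEncoder* enc_;
};

One practical payoff of set_bytes deducing the size: a call site can no longer pass &params together with a stale sizeof(SomeOtherParams), a class of bug the old three-argument form permitted.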
@@ -249,7 +249,7 @@ void steel_matmul_regular(
       wm,
       wn);
 
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Use problem size to determine threadblock swizzle
   int tn = (N + bn - 1) / bn;
@@ -288,12 +288,12 @@ void steel_matmul_regular(
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);
 
-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+  compute_encoder.set_bytes(params, 4);
 
-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);
 
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
   // Record copies
   d.add_temporaries(std::move(copies), s.index);
@@ -390,7 +390,7 @@ void steel_matmul(
         wn,
         mn_aligned,
         k_aligned);
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     int tn = (N + bn - 1) / bn;
     int tm = (M + bm - 1) / bm;
@@ -416,34 +416,30 @@ void steel_matmul(
     compute_encoder.set_input_array(b, 1);
     compute_encoder.set_output_array(C_split, 2);
 
-    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.set_bytes(params, 3);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
     // Do accum kernel
     {
-      auto c_split_buf =
-          static_cast<const MTL::Resource*>(C_split.buffer().ptr());
-      const class MTL::Resource* const resources[1] = {c_split_buf};
-      compute_encoder->memoryBarrier(resources, 1);
       auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
           type_to_name(C_split);
 
       auto kernel = get_steel_gemm_splitk_accum_kernel(
           d, kernel_name, C_split, out, false);
-      compute_encoder->setComputePipelineState(kernel);
+      compute_encoder.set_compute_pipeline_state(kernel);
 
       // Set the arguments for the kernel
       compute_encoder.set_input_array(C_split, 0);
       compute_encoder.set_output_array(out, 1);
-      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
-      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
-      compute_encoder->setBytes(&N, sizeof(int), 4);
+      compute_encoder.set_bytes(split_k_partitions, 2);
+      compute_encoder.set_bytes(split_k_partition_stride, 3);
+      compute_encoder.set_bytes(N, 4);
 
       // Launch enough thread groups for each output
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
 
-      compute_encoder.dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
     }
 
     d.add_temporaries(std::move(copies), s.index);
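One change in the hunk above goes beyond renaming: the four lines that built a MTL::Resource array and called compute_encoder->memoryBarrier(...) before the split-k accumulation kernel are deleted outright. A fully wrapped encoder sees every set_input_array/set_output_array call, so it is in a position to notice a read of a buffer it previously bound as an output and issue any needed barrier itself. The sketch below shows that idea only; the names and the exact mechanism MLX uses are assumptions.

// Hypothetical sketch: an encoder that tracks buffers bound as outputs
// and emits a barrier when one is later read. All names are invented.
#include <unordered_set>

struct BarrierTrackingEncoder {
  MTL::ComputeCommandEncoder* enc_;
  std::unordered_set<const MTL::Resource*> outputs_;  // written, not yet fenced

  void set_output_array(const MTL::Resource* buf, int idx) {
    outputs_.insert(buf);
    // ... bind buf at buffer index idx ...
  }

  void set_input_array(const MTL::Resource* buf, int idx) {
    if (outputs_.erase(buf) > 0) {
      // An earlier kernel wrote this buffer: order the read after the write.
      enc_->memoryBarrier(&buf, 1);
    }
    // ... bind buf at buffer index idx ...
  }
};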
@@ -625,7 +621,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
     MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -635,16 +631,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(vec, 1);
     compute_encoder.set_output_array(out, 3);
 
-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);
 
-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 12);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -822,7 +818,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
     MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -833,23 +829,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(c, 2);
     compute_encoder.set_output_array(out, 3);
 
-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);
 
-    compute_encoder->setBytes(&alpha_, sizeof(float), 7);
-    compute_encoder->setBytes(&beta_, sizeof(float), 8);
+    compute_encoder.set_bytes(alpha_, 7);
+    compute_encoder.set_bytes(beta_, 8);
 
-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 12);
-    set_vector_bytes(compute_encoder, C_batch_stride, 13);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 12);
+    compute_encoder.set_vector_bytes(C_batch_stride, 13);
 
     int bias_stride = c.strides()[c.ndim() - 1];
-    compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
+    compute_encoder.set_bytes(bias_stride, 14);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -907,7 +903,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         mn_aligned,
         k_aligned);
 
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     int tn = (N + bn - 1) / bn;
     int tm = (M + bm - 1) / bm;
@@ -933,8 +929,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(b, 1);
     compute_encoder.set_output_array(C_split, 2);
 
-    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.set_bytes(params, 3);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
     // Do accum kernel
     {
@@ -943,25 +939,25 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       auto kernel = get_steel_gemm_splitk_accum_kernel(
           d, kernel_name, C_split, out, true);
 
-      compute_encoder->setComputePipelineState(kernel);
+      compute_encoder.set_compute_pipeline_state(kernel);
 
       // Set the arguments for the kernel
       compute_encoder.set_input_array(C_split, 0);
       compute_encoder.set_output_array(out, 1);
-      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
-      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
-      compute_encoder->setBytes(&N, sizeof(int), 4);
+      compute_encoder.set_bytes(split_k_partitions, 2);
+      compute_encoder.set_bytes(split_k_partition_stride, 3);
+      compute_encoder.set_bytes(N, 4);
       compute_encoder.set_input_array(c, 5);
-      compute_encoder->setBytes(&ldc, sizeof(int), 6);
-      compute_encoder->setBytes(&fdc, sizeof(int), 7);
-      compute_encoder->setBytes(&alpha_, sizeof(float), 8);
-      compute_encoder->setBytes(&beta_, sizeof(float), 9);
+      compute_encoder.set_bytes(ldc, 6);
+      compute_encoder.set_bytes(fdc, 7);
+      compute_encoder.set_bytes(alpha_, 8);
+      compute_encoder.set_bytes(beta_, 9);
 
       // Launch enough thread groups for each output
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
 
-      compute_encoder.dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
     }
 
     d.add_temporaries(std::move(copies), s.index);
@@ -1032,7 +1028,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       wm,
       wn);
 
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   int tn = (N + bn - 1) / bn;
   int tm = (M + bm - 1) / bm;
@@ -1083,13 +1079,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(c, 2);
   compute_encoder.set_output_array(out, 3);
 
-  compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
-  compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
+  compute_encoder.set_bytes(gemm_params, 4);
+  compute_encoder.set_bytes(params, 5);
 
-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);
 
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
   d.add_temporaries(std::move(copies), s.index);
 }
@@ -1304,7 +1300,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         contiguous_kernel);
 
     auto& compute_encoder = d.get_command_encoder(s.index);
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
     MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1372,18 +1368,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(vec, 1);
     compute_encoder.set_output_array(out, 3);
 
-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 12);
 
-    set_vector_bytes(compute_encoder, mask_strides, 23);
-    set_vector_bytes(compute_encoder, mask_batch_strides, 24);
+    compute_encoder.set_vector_bytes(mask_strides, 23);
+    compute_encoder.set_vector_bytes(mask_batch_strides, 24);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -1423,7 +1419,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       wn,
       mn_aligned,
      k_aligned);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Use problem size to determine threadblock swizzle
   int tn = (N + bn - 1) / bn;
@@ -1486,14 +1482,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);
 
-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+  compute_encoder.set_bytes(params, 4);
 
-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);
 
-  set_vector_bytes(compute_encoder, mask_strides, 13);
+  compute_encoder.set_vector_bytes(mask_strides, 13);
 
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
   d.add_temporaries(std::move(copies), s.index);
 }
@@ -1687,7 +1683,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Encode and dispatch kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
    int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
    MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1697,28 +1693,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(vec, 1);
     compute_encoder.set_output_array(out, 3);
 
-    compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-    compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-    compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+    compute_encoder.set_bytes(in_vector_len, 4);
+    compute_encoder.set_bytes(out_vector_len, 5);
+    compute_encoder.set_bytes(mat_ld, 6);
 
-    compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-    set_vector_bytes(compute_encoder, batch_shape, 10);
-    set_vector_bytes(compute_encoder, batch_strides, 11);
+    compute_encoder.set_bytes(batch_ndim, 9);
+    compute_encoder.set_vector_bytes(batch_shape, 10);
+    compute_encoder.set_vector_bytes(batch_strides, 11);
 
     int batch_ndim_vec = batch_shape_vec.size();
-    compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
-    set_vector_bytes(compute_encoder, batch_shape_vec, 13);
-    set_vector_bytes(compute_encoder, batch_strides_vec, 14);
+    compute_encoder.set_bytes(batch_ndim_vec, 12);
+    compute_encoder.set_vector_bytes(batch_shape_vec, 13);
+    compute_encoder.set_vector_bytes(batch_strides_vec, 14);
 
     int batch_ndim_mat = batch_shape_mat.size();
-    compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
-    set_vector_bytes(compute_encoder, batch_shape_mat, 16);
-    set_vector_bytes(compute_encoder, batch_strides_mat, 17);
+    compute_encoder.set_bytes(batch_ndim_mat, 15);
+    compute_encoder.set_vector_bytes(batch_shape_mat, 16);
+    compute_encoder.set_vector_bytes(batch_strides_mat, 17);
 
     compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
     compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
     d.add_temporaries(std::move(copies), s.index);
     return;
@@ -1788,7 +1784,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       wm,
       wn);
 
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Use problem size to determine threadblock swizzle
   int tn = (N + bn - 1) / bn;
@@ -1827,10 +1823,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);
 
-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+  compute_encoder.set_bytes(params, 4);
 
-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
+  compute_encoder.set_vector_bytes(batch_shape, 6);
+  compute_encoder.set_vector_bytes(batch_strides, 7);
 
   compute_encoder.set_input_array(lhs_indices, 10);
   compute_encoder.set_input_array(rhs_indices, 11);
@@ -1845,11 +1841,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   operand_batch_ndim.push_back(0);
 
-  set_vector_bytes(compute_encoder, operand_shape, 13);
-  set_vector_bytes(compute_encoder, operand_strides, 14);
-  set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
+  compute_encoder.set_vector_bytes(operand_shape, 13);
+  compute_encoder.set_vector_bytes(operand_strides, 14);
+  compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
 
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 
   d.add_temporaries(std::move(copies), s.index);
 }