Fully wrap the command encoder (#1572)

* fully wrap the command encoder

* use consistent style + fix extensions
Author: Awni Hannun (committed by GitHub)
Date: 2024-11-08 11:50:21 -08:00
Parent: 59247c2b62
Commit: 9f0d5c12fc

27 changed files with 469 additions and 484 deletions

@@ -249,7 +249,7 @@ void steel_matmul_regular(
wm,
wn);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
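The pattern in the hunk above repeats through the whole commit: calls made through the wrapper's raw `MTL::ComputeCommandEncoder*` (the `->` style) are replaced by snake_case members on MLX's `CommandEncoder` (the `.` style), so call sites stop reaching into metal-cpp directly. A minimal sketch of what that first member plausibly forwards to; the member name is taken from the new call sites, everything else is an illustration, not the actual MLX header:

```cpp
#include <Metal/Metal.hpp>

// Illustrative sketch only: a wrapper that owns the raw Metal encoder and
// forwards to it, so kernel dispatch code never touches metal-cpp types at
// the call site. The real MLX CommandEncoder may look quite different.
class CommandEncoder {
 public:
  explicit CommandEncoder(MTL::ComputeCommandEncoder* enc) : enc_(enc) {}

  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
    enc_->setComputePipelineState(kernel);
  }

 private:
  MTL::ComputeCommandEncoder* enc_;
};
```

Later hunks introduce the remaining members (typed byte setters and the two dispatch calls); each is sketched where it first appears.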
@@ -288,12 +288,12 @@ void steel_matmul_regular(
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+compute_encoder.set_bytes(params, 4);
-set_vector_bytes(compute_encoder, batch_shape, 6);
-set_vector_bytes(compute_encoder, batch_strides, 7);
+compute_encoder.set_vector_bytes(batch_shape, 6);
+compute_encoder.set_vector_bytes(batch_strides, 7);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Record copies
d.add_temporaries(std::move(copies), s.index);
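The scalar and vector argument setters follow the same idea. The old calls pass a pointer plus an explicit `sizeof`, and vectors go through a free `set_vector_bytes(compute_encoder, ...)` helper; the new calls pass the object and the buffer index, with the size deduced inside the wrapper. A plausible sketch of those two members, again as an assumption rather than the actual implementation:

```cpp
#include <Metal/Metal.hpp>
#include <vector>

// Same hypothetical wrapper as sketched after the first hunk; only the new
// members are shown. The length that callers used to spell out by hand
// (e.g. sizeof(GEMMParams)) is now deduced from the argument type.
struct CommandEncoder {
  template <typename T>
  void set_bytes(const T& v, int idx) {
    enc_->setBytes(&v, sizeof(T), idx);
  }

  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, int idx) {
    enc_->setBytes(v.data(), v.size() * sizeof(T), idx);
  }

  MTL::ComputeCommandEncoder* enc_;
};
```

Folding the free helper into the wrapper also makes scalar and vector arguments read the same way at every call site.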
@@ -390,7 +390,7 @@ void steel_matmul(
wn,
mn_aligned,
k_aligned);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -416,34 +416,30 @@ void steel_matmul(
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
-compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.set_bytes(params, 3);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel
{
-auto c_split_buf =
-    static_cast<const MTL::Resource*>(C_split.buffer().ptr());
-const class MTL::Resource* const resources[1] = {c_split_buf};
-compute_encoder->memoryBarrier(resources, 1);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split);
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, false);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
-compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
-compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
-compute_encoder->setBytes(&N, sizeof(int), 4);
+compute_encoder.set_bytes(split_k_partitions, 2);
+compute_encoder.set_bytes(split_k_partition_stride, 3);
+compute_encoder.set_bytes(N, 4);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
-compute_encoder.dispatchThreads(grid_dims, group_dims);
+compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
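This hunk also drops the explicit `memoryBarrier` on the split-K accumulation buffer without a one-for-one replacement, which suggests barrier handling moved behind the wrapper as well; the sketch below covers only the two dispatch calls. As in metal-cpp, the threadgroup variant takes the number of threadgroups per grid dimension, while the threads variant takes the total thread count (and relies on non-uniform threadgroup support), which is why the accumulation kernel above uses `MTL::Size(N, M, 1)` as its grid. A hypothetical forwarding sketch:

```cpp
#include <Metal/Metal.hpp>

// Same hypothetical wrapper; only the dispatch members are shown.
struct CommandEncoder {
  // grid_dims = threadgroups per grid dimension.
  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreadgroups(grid_dims, group_dims);
  }

  // grid_dims = total threads per grid dimension (non-uniform threadgroups).
  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreads(grid_dims, group_dims);
  }

  MTL::ComputeCommandEncoder* enc_;
};
```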
@@ -625,7 +621,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -635,16 +631,16 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+compute_encoder.set_bytes(in_vector_len, 4);
+compute_encoder.set_bytes(out_vector_len, 5);
+compute_encoder.set_bytes(mat_ld, 6);
-compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-set_vector_bytes(compute_encoder, batch_shape, 10);
-set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+compute_encoder.set_bytes(batch_ndim, 9);
+compute_encoder.set_vector_bytes(batch_shape, 10);
+compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+compute_encoder.set_vector_bytes(batch_strides_mat, 12);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
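Taken together, a converted call site like the gemv path above ends up reading as below. This is a schematic assembled from the new lines of that hunk, using the hypothetical wrapper sketched earlier; the argument struct, function name, and stride types are stand-ins, and the array binding (`set_input_array` / `set_output_array` at indices 0-3) is omitted because those members already existed before this change:

```cpp
#include <Metal/Metal.hpp>
#include <vector>

// Stand-in bundle of the scalar and vector arguments bound at indices 4-12.
struct GemvArgs {
  int in_vector_len, out_vector_len, mat_ld, batch_ndim;
  std::vector<int> batch_shape;
  std::vector<size_t> batch_strides_vec, batch_strides_mat;
};

// Schematic only: mirrors the post-change call pattern of the hunk above.
void encode_gemv(CommandEncoder& enc,
                 MTL::ComputePipelineState* kernel,
                 const GemvArgs& a,
                 MTL::Size grid_dims,
                 MTL::Size group_dims) {
  enc.set_compute_pipeline_state(kernel);
  enc.set_bytes(a.in_vector_len, 4);
  enc.set_bytes(a.out_vector_len, 5);
  enc.set_bytes(a.mat_ld, 6);
  enc.set_bytes(a.batch_ndim, 9);
  enc.set_vector_bytes(a.batch_shape, 10);
  enc.set_vector_bytes(a.batch_strides_vec, 11);
  enc.set_vector_bytes(a.batch_strides_mat, 12);
  enc.dispatch_threadgroups(grid_dims, group_dims);
}
```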
@@ -822,7 +818,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -833,23 +829,23 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+compute_encoder.set_bytes(in_vector_len, 4);
+compute_encoder.set_bytes(out_vector_len, 5);
+compute_encoder.set_bytes(mat_ld, 6);
-compute_encoder->setBytes(&alpha_, sizeof(float), 7);
-compute_encoder->setBytes(&beta_, sizeof(float), 8);
+compute_encoder.set_bytes(alpha_, 7);
+compute_encoder.set_bytes(beta_, 8);
-compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-set_vector_bytes(compute_encoder, batch_shape, 10);
-set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-set_vector_bytes(compute_encoder, batch_strides_mat, 12);
-set_vector_bytes(compute_encoder, C_batch_stride, 13);
+compute_encoder.set_bytes(batch_ndim, 9);
+compute_encoder.set_vector_bytes(batch_shape, 10);
+compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+compute_encoder.set_vector_bytes(batch_strides_mat, 12);
+compute_encoder.set_vector_bytes(C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1];
-compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
+compute_encoder.set_bytes(bias_stride, 14);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -907,7 +903,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mn_aligned,
k_aligned);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -933,8 +929,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
-compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.set_bytes(params, 3);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -943,25 +939,25 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
-compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
-compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
-compute_encoder->setBytes(&N, sizeof(int), 4);
+compute_encoder.set_bytes(split_k_partitions, 2);
+compute_encoder.set_bytes(split_k_partition_stride, 3);
+compute_encoder.set_bytes(N, 4);
compute_encoder.set_input_array(c, 5);
-compute_encoder->setBytes(&ldc, sizeof(int), 6);
-compute_encoder->setBytes(&fdc, sizeof(int), 7);
-compute_encoder->setBytes(&alpha_, sizeof(float), 8);
-compute_encoder->setBytes(&beta_, sizeof(float), 9);
+compute_encoder.set_bytes(ldc, 6);
+compute_encoder.set_bytes(fdc, 7);
+compute_encoder.set_bytes(alpha_, 8);
+compute_encoder.set_bytes(beta_, 9);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
-compute_encoder.dispatchThreads(grid_dims, group_dims);
+compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
@@ -1032,7 +1028,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm,
wn);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
@@ -1083,13 +1079,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
-compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
+compute_encoder.set_bytes(gemm_params, 4);
+compute_encoder.set_bytes(params, 5);
-set_vector_bytes(compute_encoder, batch_shape, 6);
-set_vector_bytes(compute_encoder, batch_strides, 7);
+compute_encoder.set_vector_bytes(batch_shape, 6);
+compute_encoder.set_vector_bytes(batch_strides, 7);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
@@ -1304,7 +1300,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
contiguous_kernel);
auto& compute_encoder = d.get_command_encoder(s.index);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1372,18 +1368,18 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
-compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-set_vector_bytes(compute_encoder, batch_shape, 10);
-set_vector_bytes(compute_encoder, batch_strides_vec, 11);
-set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+compute_encoder.set_bytes(in_vector_len, 4);
+compute_encoder.set_bytes(out_vector_len, 5);
+compute_encoder.set_bytes(mat_ld, 6);
+compute_encoder.set_bytes(batch_ndim, 9);
+compute_encoder.set_vector_bytes(batch_shape, 10);
+compute_encoder.set_vector_bytes(batch_strides_vec, 11);
+compute_encoder.set_vector_bytes(batch_strides_mat, 12);
-set_vector_bytes(compute_encoder, mask_strides, 23);
-set_vector_bytes(compute_encoder, mask_batch_strides, 24);
+compute_encoder.set_vector_bytes(mask_strides, 23);
+compute_encoder.set_vector_bytes(mask_batch_strides, 24);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -1423,7 +1419,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wn,
mn_aligned,
k_aligned);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -1486,14 +1482,14 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+compute_encoder.set_bytes(params, 4);
-set_vector_bytes(compute_encoder, batch_shape, 6);
-set_vector_bytes(compute_encoder, batch_strides, 7);
+compute_encoder.set_vector_bytes(batch_shape, 6);
+compute_encoder.set_vector_bytes(batch_strides, 7);
-set_vector_bytes(compute_encoder, mask_strides, 13);
+compute_encoder.set_vector_bytes(mask_strides, 13);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
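One consequence of the typed setters worth noting, since nearly every hunk in this file touches a parameter struct like `GEMMParams` or `GEMMSpiltKParams`: with the raw API the pointer and the length are independent arguments, so a size mismatch compiles silently, while the wrapped call derives both from the same object. A contrived illustration using stand-in struct definitions and the hypothetical wrapper sketched earlier (the real structs live in MLX's steel GEMM headers):

```cpp
#include <Metal/Metal.hpp>

// Stand-ins for illustration only; not the real MLX definitions.
struct GEMMParams { int M, N, K; };
struct GEMMSpiltKParams { int M, N, K, split_k; };

// Old interface: nothing ties the pointer to the length, so taking sizeof of
// the wrong struct compiles without any diagnostic (a made-up mistake).
void bind_old(MTL::ComputeCommandEncoder* enc, const GEMMParams& p) {
  enc->setBytes(&p, sizeof(GEMMSpiltKParams), 4);
}

// Wrapped interface: the length is computed from the argument itself, so this
// class of mismatch is no longer expressible at the call site.
void bind_new(CommandEncoder& enc, const GEMMParams& p) {
  enc.set_bytes(p, 4);
}
```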
@@ -1687,7 +1683,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
@@ -1697,28 +1693,28 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
-compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
-compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
+compute_encoder.set_bytes(in_vector_len, 4);
+compute_encoder.set_bytes(out_vector_len, 5);
+compute_encoder.set_bytes(mat_ld, 6);
-compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
-set_vector_bytes(compute_encoder, batch_shape, 10);
-set_vector_bytes(compute_encoder, batch_strides, 11);
+compute_encoder.set_bytes(batch_ndim, 9);
+compute_encoder.set_vector_bytes(batch_shape, 10);
+compute_encoder.set_vector_bytes(batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
-compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
-set_vector_bytes(compute_encoder, batch_shape_vec, 13);
-set_vector_bytes(compute_encoder, batch_strides_vec, 14);
+compute_encoder.set_bytes(batch_ndim_vec, 12);
+compute_encoder.set_vector_bytes(batch_shape_vec, 13);
+compute_encoder.set_vector_bytes(batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
-compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
-set_vector_bytes(compute_encoder, batch_shape_mat, 16);
-set_vector_bytes(compute_encoder, batch_strides_mat, 17);
+compute_encoder.set_bytes(batch_ndim_mat, 15);
+compute_encoder.set_vector_bytes(batch_shape_mat, 16);
+compute_encoder.set_vector_bytes(batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
@@ -1788,7 +1784,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
wm,
wn);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
@@ -1827,10 +1823,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
-compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+compute_encoder.set_bytes(params, 4);
-set_vector_bytes(compute_encoder, batch_shape, 6);
-set_vector_bytes(compute_encoder, batch_strides, 7);
+compute_encoder.set_vector_bytes(batch_shape, 6);
+compute_encoder.set_vector_bytes(batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
@@ -1845,11 +1841,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
operand_batch_ndim.push_back(0);
-set_vector_bytes(compute_encoder, operand_shape, 13);
-set_vector_bytes(compute_encoder, operand_strides, 14);
-set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
+compute_encoder.set_vector_bytes(operand_shape, 13);
+compute_encoder.set_vector_bytes(operand_strides, 14);
+compute_encoder.set_vector_bytes(operand_batch_ndim, 15);
-compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}