Fully wrap the command encoder (#1572)

* fully wrap the command encoder

* use consistent style + fix extensions
This commit is contained in:
Awni Hannun
2024-11-08 11:50:21 -08:00
committed by GitHub
parent 59247c2b62
commit 9f0d5c12fc
27 changed files with 469 additions and 484 deletions

View File

@@ -63,7 +63,7 @@ void ternary_op_gpu_inplace(
auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr;
@@ -80,18 +80,18 @@ void ternary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
compute_encoder.set_vector_bytes(shape, 4);
compute_encoder.set_vector_bytes(strides_a, 5);
compute_encoder.set_vector_bytes(strides_b, 6);
compute_encoder.set_vector_bytes(strides_c, 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
compute_encoder.set_bytes(ndim, 8);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
compute_encoder.set_vector_bytes(strides_a, 4);
compute_encoder.set_vector_bytes(strides_b, 5);
compute_encoder.set_vector_bytes(strides_c, 6);
}
if (thread_group_size != 1024) {
@@ -99,7 +99,7 @@ void ternary_op_gpu_inplace(
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
@@ -109,7 +109,7 @@ void ternary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}