Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Fully wrap the command encoder (#1572)
* fully wrap the command encoder
* use consistent style + fix extensions
@@ -63,7 +63,7 @@ void ternary_op_gpu_inplace(
   auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);

   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
   bool donate_c = c.data_shared_ptr() == nullptr;
@@ -80,18 +80,18 @@ void ternary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);

     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
-      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
+      compute_encoder.set_vector_bytes(shape, 4);
+      compute_encoder.set_vector_bytes(strides_a, 5);
+      compute_encoder.set_vector_bytes(strides_b, 6);
+      compute_encoder.set_vector_bytes(strides_c, 7);

-      compute_encoder->setBytes(&ndim, sizeof(int), 8);
+      compute_encoder.set_bytes(ndim, 8);
       dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
+      compute_encoder.set_vector_bytes(strides_a, 4);
+      compute_encoder.set_vector_bytes(strides_b, 5);
+      compute_encoder.set_vector_bytes(strides_c, 6);
     }

     if (thread_group_size != 1024) {
@@ -99,7 +99,7 @@ void ternary_op_gpu_inplace(
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
@@ -109,7 +109,7 @@ void ternary_op_gpu_inplace(
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
                                  : MTL::Size(nthreads, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
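
The snake_case calls introduced above come from a C++ wrapper around the raw MTL::ComputeCommandEncoder. As a rough sketch only, and not MLX's actual CommandEncoder (which presumably also handles concerns such as tracking input/output arrays), a minimal wrapper exposing the methods used in this diff could look like the following; it assumes the metal-cpp headers, and the class name CommandEncoderSketch is hypothetical:

// Minimal sketch of a command-encoder wrapper (assumes metal-cpp).
#include <Metal/Metal.hpp>
#include <vector>

struct CommandEncoderSketch {
  explicit CommandEncoderSketch(MTL::ComputeCommandEncoder* enc) : enc_(enc) {}

  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
    enc_->setComputePipelineState(kernel);
  }

  // Plain value argument: the byte count is inferred from the type.
  template <typename T>
  void set_bytes(const T& value, int index) {
    enc_->setBytes(&value, sizeof(T), index);
  }

  // Vector argument: the byte count is inferred from the element count,
  // so call sites no longer spell out data() and ndim * sizeof(...) by hand.
  template <typename T>
  void set_vector_bytes(const std::vector<T>& vec, int index) {
    enc_->setBytes(vec.data(), vec.size() * sizeof(T), index);
  }

  void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreads(grid_dims, group_dims);
  }

 private:
  MTL::ComputeCommandEncoder* enc_;
};

The visible benefit at the call sites in this diff is that size arithmetic such as ndim * sizeof(size_t) moves into the wrapper, and every kernel launch goes through one consistent, dot-style interface instead of mixing wrapped and raw encoder calls.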