mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Fully wrap the command encoder (#1572)
* fully wrap the command encoder * use consistent style + fix extensions
This commit is contained in:
@@ -494,7 +494,7 @@ below.
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel declaration at axpby.metal
|
||||
@@ -509,14 +509,14 @@ below.
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
compute_encoder.set_bytes(alpha_, 3);
|
||||
compute_encoder.set_bytes(beta_, 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
|
||||
// We launch 1 thread for each input and make sure that the number of
|
||||
// threads in any given threadgroup is not higher than the max allowed
|
||||
@@ -530,7 +530,7 @@ below.
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
|
Reference in New Issue
Block a user