mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fully wrap the command encoder (#1572)
* fully wrap the command encoder * use consistent style + fix extensions
This commit is contained in:
@@ -68,12 +68,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
if (contiguous) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(
|
||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
size_t size = in.shape(axis_);
|
||||
compute_encoder->setBytes(&size, sizeof(size_t), 2);
|
||||
compute_encoder.set_bytes(size, 2);
|
||||
|
||||
// Compute the thread grid
|
||||
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
||||
@@ -95,10 +95,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size grid_dims(
|
||||
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
|
||||
MTL::Size group_dims(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
} else {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(
|
||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
@@ -107,9 +107,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int bm = 32;
|
||||
int bn = 32;
|
||||
size_t stride_blocks = (stride + bn - 1) / bn;
|
||||
compute_encoder->setBytes(&size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
|
||||
compute_encoder.set_bytes(size, 2);
|
||||
compute_encoder.set_bytes(stride, 3);
|
||||
compute_encoder.set_bytes(stride_blocks, 4);
|
||||
|
||||
// Compute the thread grid
|
||||
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
||||
@@ -125,7 +125,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size grid_dims(
|
||||
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
|
||||
MTL::Size group_dims(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
|
||||
Reference in New Issue
Block a user