Fully wrap the command encoder (#1572)

* fully wrap the command encoder

* use consistent style + fix extensions
This commit is contained in:
Awni Hannun
2024-11-08 11:50:21 -08:00
committed by GitHub
parent 59247c2b62
commit 9f0d5c12fc
27 changed files with 469 additions and 484 deletions

View File

@@ -81,12 +81,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);