Fix ternary for large arrays (#1359)

* fix ternary for large arrays

* fix
This commit is contained in:
Awni Hannun
2024-08-26 11:22:27 -07:00
committed by GitHub
parent 860d3a50d7
commit 2fdf9eb535
2 changed files with 37 additions and 5 deletions

View File

@@ -56,9 +56,12 @@ void ternary_op_gpu_inplace(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_input_array(donate_c ? out : c, 2);
compute_encoder.set_output_array(out, 3);
if (topt == TernaryOpType::General) {
@@ -91,9 +94,10 @@ void ternary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;