This commit is contained in:
Awni Hannun 2025-05-02 20:25:28 -07:00
parent 55d6edcaa3
commit 01a29b51c8

View File

@ -145,10 +145,10 @@ void binary_op_gpu_inplace(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(a.size(), arg_idx++);
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(a.size(), arg_idx++);
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
grid_dims = MTL::Size(nthreads, 1, 1);
}
compute_encoder.dispatch_threads(grid_dims, group_dims);