fix for max block dim (#2631)

This commit is contained in:
Awni Hannun
2025-09-29 08:59:25 -07:00
committed by GitHub
parent e76a8dd5c5
commit dc371ae7a5
7 changed files with 67 additions and 21 deletions

View File

@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
 encoder.set_output_array(out);
 }
-auto kernel = mod.get_kernel(kernel_name);
+auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
 auto [num_blocks, block_dims] =
-get_launch_args(outputs[0], large, work_per_thread);
+get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
 encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }