mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix for max block dim (#2631)
This commit is contained in:
@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(outputs[0], large, work_per_thread);
|
||||
get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user