fix for max block dim (#2631)

This commit is contained in:
Awni Hannun
2025-09-29 08:59:25 -07:00
committed by GitHub
parent e76a8dd5c5
commit dc371ae7a5
7 changed files with 67 additions and 21 deletions

View File

@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1);
int work_per_thread = 1,
uint max_block_dim = 1024);
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
inline std::tuple<dim3, uint> get_launch_args(
const array& arr,
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024) {
return get_launch_args(
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
arr.size(),
arr.shape(),
arr.strides(),
large,
work_per_thread,
max_block_dim);
}
} // namespace mlx::core