mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix for max block dim (#2631)
This commit is contained in:
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
|
||||
size_t divisor);
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
|
||||
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
|
||||
// assuming each thread handles |work_per_thread| elements of |arr|.
|
||||
// Get the num_blocks and block_dims assuming each thread handles
|
||||
// |work_per_thread| elements of |arr|.
|
||||
std::tuple<dim3, uint> get_launch_args(
|
||||
size_t size,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
int work_per_thread = 1,
|
||||
uint max_block_dim = 1024);
|
||||
|
||||
inline std::tuple<dim3, uint>
|
||||
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1,
|
||||
uint max_block_dim = 1024) {
|
||||
return get_launch_args(
|
||||
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
arr.size(),
|
||||
arr.shape(),
|
||||
arr.strides(),
|
||||
large,
|
||||
work_per_thread,
|
||||
max_block_dim);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user