mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use the same tuning for looped
This commit is contained in:
@@ -270,8 +270,6 @@ void row_reduce_looped(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const ReductionPlan& plan,
|
const ReductionPlan& plan,
|
||||||
cu::RowReduceArgs args) {
|
cu::RowReduceArgs args) {
|
||||||
constexpr int N_READS = 8;
|
|
||||||
|
|
||||||
// Allocate data for the output using in's layout to access them as
|
// Allocate data for the output using in's layout to access them as
|
||||||
// contiguously as possible.
|
// contiguously as possible.
|
||||||
allocate_same_layout(out, in, axes);
|
allocate_same_layout(out, in, axes);
|
||||||
@@ -284,12 +282,27 @@ void row_reduce_looped(
|
|||||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
using U = typename cu::ReduceResult<OP, T>::type;
|
using U = typename cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
|
constexpr int N_READS = 16 / sizeof(T);
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
// Calculate the grid and block dims
|
||||||
args.sort_access_pattern(in, axes);
|
args.sort_access_pattern(in, axes);
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||||
int threads = std::min(1024UL, reductions);
|
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
if (warps > 128) {
|
||||||
|
warps = 32;
|
||||||
|
} else {
|
||||||
|
warps = 16;
|
||||||
|
}
|
||||||
|
int best = reductions;
|
||||||
|
for (int j = warps; j >= 1; j /= 2) {
|
||||||
|
int t = reductions % (j * WARP_SIZE);
|
||||||
|
if (t < best) {
|
||||||
|
warps = j;
|
||||||
|
best = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int threads = warps * WARP_SIZE;
|
||||||
dim3 block(threads, 1, 1);
|
dim3 block(threads, 1, 1);
|
||||||
|
|
||||||
// Pick the kernel
|
// Pick the kernel
|
||||||
|
|||||||
Reference in New Issue
Block a user