mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fixes for transpositions and expands
This commit is contained in:
parent
fd1d0821d2
commit
8bd4bf2393
@ -34,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
// recompute the plan.
|
||||
if (plan.type == GeneralReduce) {
|
||||
//
|
||||
// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
|
||||
// like we do in Metal. When it comes to broadcasted reduction axes
|
||||
// some can be ignored eg for min/max.
|
||||
bool broadcasted = false;
|
||||
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
|
||||
if (j < axes_.size() && axes_[j] == i) {
|
||||
j++;
|
||||
} else {
|
||||
broadcasted = in.strides(i) == 0;
|
||||
}
|
||||
}
|
||||
if (plan.type == GeneralReduce || broadcasted) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
|
@ -104,13 +104,24 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
if (args.reduction_stride % N_READS == 0) {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
} else {
|
||||
for (size_t r = thread_y; r < total; r += BM) {
|
||||
@ -157,11 +168,13 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
Shape out_shape;
|
||||
Strides out_strides;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
if (out.strides(i) >= args.reduction_stride) {
|
||||
out_shape.push_back(out.shape(i));
|
||||
out_strides.push_back(out.strides(i));
|
||||
}
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ inline void allocate_same_layout(
|
||||
auto out_strides = in.strides();
|
||||
for (auto ax : axes) {
|
||||
for (auto& s : out_strides) {
|
||||
if (s > in.strides(ax)) {
|
||||
if (s > in.strides(ax) && in.strides(ax) > 0) {
|
||||
s /= in.shape(ax);
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
@ -57,20 +59,24 @@ struct RowReduceArgs {
|
||||
}
|
||||
|
||||
// Convert shape and strides as if in was contiguous
|
||||
void convert_shapes_to_contiguous(
|
||||
const array& in,
|
||||
const std::vector<int>& axes) {
|
||||
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
|
||||
auto shape_vec = in.shape();
|
||||
auto strides_vec = in.strides();
|
||||
size_t s = 1;
|
||||
for (int i = in.ndim() - 1; i >= 0; i--) {
|
||||
strides_vec[i] = s;
|
||||
s *= shape_vec[i];
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
|
||||
std::vector<int> indices(shape_vec.size());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||
return strides_vec[left] > strides_vec[right];
|
||||
});
|
||||
decltype(shape_vec) sorted_shape;
|
||||
decltype(strides_vec) sorted_strides;
|
||||
for (auto idx : indices) {
|
||||
sorted_shape.push_back(shape_vec[idx]);
|
||||
sorted_strides.push_back(strides_vec[idx]);
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
@ -282,7 +288,7 @@ void row_reduce_looped(
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.convert_shapes_to_contiguous(x, axes);
|
||||
args.sort_access_pattern(x, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
|
@ -1,7 +1,6 @@
|
||||
cuda_skip = {
|
||||
"TestArray.test_api",
|
||||
"TestBF16.test_arg_reduction_ops",
|
||||
"TestBF16.test_reduction_ops",
|
||||
"TestBlas.test_complex_gemm",
|
||||
"TestEinsum.test_ellipses",
|
||||
"TestEinsum.test_opt_einsum_test_cases",
|
||||
@ -15,8 +14,6 @@ cuda_skip = {
|
||||
"TestOps.test_dynamic_slicing",
|
||||
"TestReduce.test_axis_permutation_sums",
|
||||
"TestReduce.test_dtypes",
|
||||
"TestReduce.test_expand_sums",
|
||||
"TestReduce.test_many_reduction_axes",
|
||||
"TestUpsample.test_torch_upsample",
|
||||
# Block masked matmul NYI
|
||||
"TestBlas.test_block_masked_matmul",
|
||||
|
Loading…
Reference in New Issue
Block a user