mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fixes for transpositions and expands
This commit is contained in:
parent
fd1d0821d2
commit
8bd4bf2393
@ -34,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// If it is a general reduce then copy the input to a contiguous array and
|
// If it is a general reduce then copy the input to a contiguous array and
|
||||||
// recompute the plan.
|
// recompute the plan.
|
||||||
if (plan.type == GeneralReduce) {
|
//
|
||||||
|
// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
|
||||||
|
// like we do in Metal. When it comes to broadcasted reduction axes
|
||||||
|
// some can be ignored eg for min/max.
|
||||||
|
bool broadcasted = false;
|
||||||
|
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
|
||||||
|
if (j < axes_.size() && axes_[j] == i) {
|
||||||
|
j++;
|
||||||
|
} else {
|
||||||
|
broadcasted = in.strides(i) == 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (plan.type == GeneralReduce || broadcasted) {
|
||||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||||
copy_gpu(in, in_copy, CopyType::General, s);
|
copy_gpu(in, in_copy, CopyType::General, s);
|
||||||
encoder.add_temporary(in_copy);
|
encoder.add_temporary(in_copy);
|
||||||
|
@ -104,6 +104,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
size_t total = args.non_col_reductions * args.reduction_size;
|
size_t total = args.non_col_reductions * args.reduction_size;
|
||||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||||
|
if (args.reduction_stride % N_READS == 0) {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = thread_y; r < total; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||||
@ -112,6 +113,16 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
}
|
}
|
||||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t r = thread_y; r < total; r += BM) {
|
||||||
|
T vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||||
|
}
|
||||||
|
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = thread_y; r < total; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
@ -157,11 +168,13 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
inline auto output_grid_for_col_reduce(
|
inline auto output_grid_for_col_reduce(
|
||||||
const array& out,
|
const array& out,
|
||||||
const cu::ColReduceArgs& args) {
|
const cu::ColReduceArgs& args) {
|
||||||
auto out_shape = out.shape();
|
Shape out_shape;
|
||||||
auto out_strides = out.strides();
|
Strides out_strides;
|
||||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
for (int i = 0; i < out.ndim(); i++) {
|
||||||
out_shape.pop_back();
|
if (out.strides(i) >= args.reduction_stride) {
|
||||||
out_strides.pop_back();
|
out_shape.push_back(out.shape(i));
|
||||||
|
out_strides.push_back(out.strides(i));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return get_2d_grid_dims(out_shape, out_strides);
|
return get_2d_grid_dims(out_shape, out_strides);
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ inline void allocate_same_layout(
|
|||||||
auto out_strides = in.strides();
|
auto out_strides = in.strides();
|
||||||
for (auto ax : axes) {
|
for (auto ax : axes) {
|
||||||
for (auto& s : out_strides) {
|
for (auto& s : out_strides) {
|
||||||
if (s > in.strides(ax)) {
|
if (s > in.strides(ax) && in.strides(ax) > 0) {
|
||||||
s /= in.shape(ax);
|
s /= in.shape(ax);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
@ -57,20 +59,24 @@ struct RowReduceArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Convert shape and strides as if in was contiguous
|
// Convert shape and strides as if in was contiguous
|
||||||
void convert_shapes_to_contiguous(
|
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
|
||||||
const array& in,
|
|
||||||
const std::vector<int>& axes) {
|
|
||||||
auto shape_vec = in.shape();
|
auto shape_vec = in.shape();
|
||||||
auto strides_vec = in.strides();
|
auto strides_vec = in.strides();
|
||||||
size_t s = 1;
|
|
||||||
for (int i = in.ndim() - 1; i >= 0; i--) {
|
|
||||||
strides_vec[i] = s;
|
|
||||||
s *= shape_vec[i];
|
|
||||||
}
|
|
||||||
std::tie(shape_vec, strides_vec) =
|
std::tie(shape_vec, strides_vec) =
|
||||||
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
|
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
|
||||||
|
std::vector<int> indices(shape_vec.size());
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||||
|
return strides_vec[left] > strides_vec[right];
|
||||||
|
});
|
||||||
|
decltype(shape_vec) sorted_shape;
|
||||||
|
decltype(strides_vec) sorted_strides;
|
||||||
|
for (auto idx : indices) {
|
||||||
|
sorted_shape.push_back(shape_vec[idx]);
|
||||||
|
sorted_strides.push_back(strides_vec[idx]);
|
||||||
|
}
|
||||||
std::tie(shape_vec, strides_vec) =
|
std::tie(shape_vec, strides_vec) =
|
||||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
||||||
shape = const_param(shape_vec);
|
shape = const_param(shape_vec);
|
||||||
strides = const_param(strides_vec);
|
strides = const_param(strides_vec);
|
||||||
ndim = shape_vec.size();
|
ndim = shape_vec.size();
|
||||||
@ -282,7 +288,7 @@ void row_reduce_looped(
|
|||||||
using U = cu::ReduceResult<OP, T>::type;
|
using U = cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
// Calculate the grid and block dims
|
// Calculate the grid and block dims
|
||||||
args.convert_shapes_to_contiguous(x, axes);
|
args.sort_access_pattern(x, axes);
|
||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||||
int threads = std::min(1024UL, reductions);
|
int threads = std::min(1024UL, reductions);
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestArray.test_api",
|
"TestArray.test_api",
|
||||||
"TestBF16.test_arg_reduction_ops",
|
"TestBF16.test_arg_reduction_ops",
|
||||||
"TestBF16.test_reduction_ops",
|
|
||||||
"TestBlas.test_complex_gemm",
|
"TestBlas.test_complex_gemm",
|
||||||
"TestEinsum.test_ellipses",
|
"TestEinsum.test_ellipses",
|
||||||
"TestEinsum.test_opt_einsum_test_cases",
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
@ -15,8 +14,6 @@ cuda_skip = {
|
|||||||
"TestOps.test_dynamic_slicing",
|
"TestOps.test_dynamic_slicing",
|
||||||
"TestReduce.test_axis_permutation_sums",
|
"TestReduce.test_axis_permutation_sums",
|
||||||
"TestReduce.test_dtypes",
|
"TestReduce.test_dtypes",
|
||||||
"TestReduce.test_expand_sums",
|
|
||||||
"TestReduce.test_many_reduction_axes",
|
|
||||||
"TestUpsample.test_torch_upsample",
|
"TestUpsample.test_torch_upsample",
|
||||||
# Block masked matmul NYI
|
# Block masked matmul NYI
|
||||||
"TestBlas.test_block_masked_matmul",
|
"TestBlas.test_block_masked_matmul",
|
||||||
|
Loading…
Reference in New Issue
Block a user