mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
6 Commits
fd1d0821d2
...
0ce20290b9
Author | SHA1 | Date | |
---|---|---|---|
![]() |
0ce20290b9 | ||
![]() |
6a59c92457 | ||
![]() |
cd31120671 | ||
![]() |
1218893b39 | ||
![]() |
ed3f6752bf | ||
![]() |
5e654b2525 |
@ -5,6 +5,7 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.mps
|
||||
|
||||
|
||||
@ -44,8 +45,10 @@ def bench(f, *args):
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
if x.device != torch.device("cpu"):
|
||||
if x.device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
elif x.device == torch.device("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sum_and_add(axis, x, y):
|
||||
z = x.sum(axis=axis, keepdims=True)
|
||||
for i in range(50):
|
||||
z = (z + y).sum(axis=axis, keepdims=True)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax(axis, x):
|
||||
ys = []
|
||||
@ -340,7 +351,11 @@ if __name__ == "__main__":
|
||||
args.axis.pop(0)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = "cpu" if args.cpu else "mps"
|
||||
device = "mps"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
if args.cpu:
|
||||
device = "cpu"
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
@ -460,5 +475,8 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
elif args.benchmark == "sum_and_add":
|
||||
print(bench(sum_and_add, axis, *xs))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
@ -29,9 +29,9 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
|
@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
|
@ -21,29 +21,11 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
return;
|
||||
throw std::runtime_error("Should never reach here.");
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
@ -59,9 +41,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
all_reduce(encoder, in, out, reduce_type_);
|
||||
return;
|
||||
}
|
||||
|
||||
|
153
mlx/backend/cuda/reduce/all_reduce.cu
Normal file
153
mlx/backend/cuda/reduce/all_reduce.cu
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4>
|
||||
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[N];
|
||||
U accs[N];
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
|
||||
size_t start = grid.block_rank() * block_step;
|
||||
size_t end = start + block_step;
|
||||
size_t check = min(end, size);
|
||||
|
||||
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[j] = op(accs[j], __cast<U, T>(vals[j]));
|
||||
}
|
||||
}
|
||||
|
||||
if (end > size) {
|
||||
size_t offset = end - block.size() * N;
|
||||
int block_end = size - offset;
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[i] = op(accs[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 1; i < N; i++) {
|
||||
accs[0] = op(accs[0], accs[i]);
|
||||
}
|
||||
accs[0] = cg::reduce(warp, accs[0], op);
|
||||
|
||||
if (warp.meta_group_size() > 1) {
|
||||
__shared__ U shared_accumulators[32];
|
||||
if (warp.thread_rank() == 0) {
|
||||
shared_accumulators[warp.meta_group_rank()] = accs[0];
|
||||
}
|
||||
block.sync();
|
||||
accs[0] = (warp.thread_rank() < warp.meta_group_size())
|
||||
? shared_accumulators[warp.thread_rank()]
|
||||
: init;
|
||||
accs[0] = cg::reduce(warp, accs[0], op);
|
||||
}
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[grid.block_rank()] = accs[0];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto get_args = [](size_t size, int N) {
|
||||
size_t reductions = size / N;
|
||||
int threads = 512;
|
||||
size_t full_blocks = (reductions + threads - 1) / threads;
|
||||
int blocks;
|
||||
if (full_blocks < 32) {
|
||||
blocks = 1;
|
||||
} else if (full_blocks < 128) {
|
||||
blocks = 32;
|
||||
} else if (full_blocks < 512) {
|
||||
blocks = 128;
|
||||
} else if (full_blocks < 1024) {
|
||||
blocks = 512;
|
||||
} else {
|
||||
blocks = 1024;
|
||||
}
|
||||
size_t reductions_per_block = std::max(
|
||||
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
|
||||
size_t block_step = reductions_per_block * N;
|
||||
|
||||
return std::make_tuple(blocks, threads, block_step);
|
||||
};
|
||||
|
||||
int blocks, threads;
|
||||
size_t block_step;
|
||||
array x = in;
|
||||
|
||||
// Large array so allocate an intermediate and accumulate there
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
if (blocks > 1) {
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
x.data<T>(), intermediate.data<U>(), block_step, x.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Set the input for the next step and recalculate the blocks
|
||||
x = intermediate;
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), block_step, x.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -47,6 +47,12 @@ namespace mlx::core {
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
}
|
||||
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type);
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
|
@ -3,48 +3,89 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
|
||||
__device__ void atomic_update(bool* x, bool y) {
|
||||
atomic_reduce<bool, And>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
__device__ __forceinline__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
|
||||
__device__ void atomic_update(bool* x, bool y) {
|
||||
atomic_reduce<bool, Or>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Sum>(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(int* x, int y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(float* x, float y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Prod>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Min>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Max>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
@ -120,7 +161,7 @@ template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
return T{1, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
|
65
mlx/backend/cuda/reduce/reduce_utils.cuh
Normal file
65
mlx/backend/cuda/reduce/reduce_utils.cuh
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <size_t N>
|
||||
struct uint_by_size;
|
||||
template <>
|
||||
struct uint_by_size<2> {
|
||||
using type = uint16_t;
|
||||
};
|
||||
template <>
|
||||
struct uint_by_size<4> {
|
||||
using type = uint32_t;
|
||||
};
|
||||
template <>
|
||||
struct uint_by_size<8> {
|
||||
using type = unsigned long long int;
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
__device__ void atomic_reduce(T* x, T y) {
|
||||
if constexpr (sizeof(T) == 1) {
|
||||
using U = uint16_t;
|
||||
U* x_int = (U*)((char*)x - ((size_t)x % 2));
|
||||
int shift = ((char*)x - (char*)x_int) * 8;
|
||||
int mask = 0xff << shift;
|
||||
U old_val, new_val;
|
||||
do {
|
||||
old_val = *x_int;
|
||||
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
|
||||
new_val = (old_val & ~mask) | (result << shift);
|
||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||
} else {
|
||||
using U = typename uint_by_size<sizeof(T)>::type;
|
||||
U* x_int = (U*)(x);
|
||||
U old_val, new_val;
|
||||
do {
|
||||
old_val = *x_int;
|
||||
T result = Op{}(*((T*)&old_val), y);
|
||||
new_val = *((U*)&result);
|
||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Should make a custom complex type
|
||||
template <typename U, typename T>
|
||||
inline __device__ U __cast(T x) {
|
||||
return static_cast<U>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
|
||||
return x.x != 0 && x.y != 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
|
||||
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -136,6 +136,91 @@ __global__ void row_reduce_small_warp(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[M][N];
|
||||
U accs[M];
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
|
||||
const size_t start_row =
|
||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||
in += start_row * size;
|
||||
out += start_row;
|
||||
|
||||
int i = 0;
|
||||
for (; i + block.size() * N <= size; i += block.size() * N) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(
|
||||
block.thread_rank(), in + k * size + i, vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (size > i) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + i,
|
||||
vals[k],
|
||||
size,
|
||||
__cast<T, U>(init));
|
||||
for (int j = 0; i < N; i++) {
|
||||
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = cg::reduce(warp, accs[i], op);
|
||||
}
|
||||
|
||||
if (warp.meta_group_size() > 1) {
|
||||
__shared__ U shared_accumulators[32 * M];
|
||||
if (warp.thread_rank() == 0) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
shared_accumulators[warp.meta_group_rank() * M + i] = accs[i];
|
||||
}
|
||||
}
|
||||
block.sync();
|
||||
if (warp.thread_rank() < warp.meta_group_size()) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = shared_accumulators[warp.thread_rank() * M + i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = cg::reduce(warp, accs[i], op);
|
||||
}
|
||||
}
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
if (grid.block_rank() * M + M <= n_rows) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
} else {
|
||||
short offset = grid.block_rank() * M + M - n_rows;
|
||||
for (int i = offset; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
@ -144,12 +229,13 @@ template <
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
@ -160,20 +246,31 @@ __global__ void row_reduce_looped(
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
|
||||
size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + r * BLOCK_DIM_X * N_READS,
|
||||
vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(total_val, __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
block.thread_rank(),
|
||||
in + loop.location() + final_offset,
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
args.row_size - final_offset,
|
||||
__cast<T, U>(ReduceInit<Op, T>::value()));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(total_val, __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
@ -190,6 +287,138 @@ __global__ void row_reduce_looped(
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce_simple(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Initialize out such that its strides match in's layout (except the fastest
|
||||
// moving axis)
|
||||
auto out_strides = in.strides();
|
||||
for (auto& s : out_strides) {
|
||||
s /= plan.shape.back();
|
||||
}
|
||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
auto fl = in.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = data_size == out.size();
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
data_size,
|
||||
out_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
|
||||
// Just a way to get out of the constness because cub doesn't like it ...
|
||||
// (sigh)
|
||||
array x = in;
|
||||
|
||||
// TODO: If out.size() < 1024 which will be a common case then write this in
|
||||
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = plan.shape.back() / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
||||
if (grid.x >= 1024) {
|
||||
grid.x = (grid.x + 1) / 2;
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void row_reduce_looped(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Initialize out such that it matches in's layout. Basically we keep any
|
||||
// transpositions as it were and that allows us to skip finding the location
|
||||
// of the output that matches the input.
|
||||
auto out_strides = in.strides();
|
||||
for (auto ax : axes) {
|
||||
for (auto& s : out_strides) {
|
||||
if (s > in.strides(ax)) {
|
||||
s /= in.shape(ax);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
|
||||
auto fl = in.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = data_size == out.size();
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
data_size,
|
||||
out_strides,
|
||||
fl,
|
||||
allocator::free);
|
||||
|
||||
// Just a way to get out of the constness because cub doesn't like it ...
|
||||
// (sigh)
|
||||
array x = in;
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Calculate the grid and block dims
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = args.row_size / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
|
||||
kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
|
||||
block.x = THREADS;
|
||||
});
|
||||
});
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
@ -197,54 +426,62 @@ void row_reduce(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
// Simple row reduce means that we have 1 axis that we are reducing over and
|
||||
// it has stride 1.
|
||||
if (plan.shape.size() == 1) {
|
||||
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
|
||||
}
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr size_t N_READS = 4;
|
||||
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims, num_blocks;
|
||||
auto kernel =
|
||||
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
if (args.row_size <= 64) {
|
||||
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
(args.non_row_reductions <= 8)) {
|
||||
block_dims.x = std::min(out_dims.x, 1024u);
|
||||
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
num_blocks.y = out_dims.y;
|
||||
} else {
|
||||
block_dims.x = WARP_SIZE;
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
kernel =
|
||||
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||
}
|
||||
} else {
|
||||
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
block_dims.x = BLOCK_DIM_X;
|
||||
kernel = cu::row_reduce_looped<
|
||||
InType,
|
||||
OutType,
|
||||
OP,
|
||||
NDIM,
|
||||
BLOCK_DIM_X,
|
||||
N_READS>;
|
||||
});
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// Fallback row reduce
|
||||
row_reduce_looped(encoder, in, out, reduce_type, axes, plan);
|
||||
|
||||
// encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
// MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
// using InType = cuda_type_t<CTYPE>;
|
||||
// MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
// using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
// MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
// constexpr size_t N_READS = 4;
|
||||
// dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
// dim3 block_dims, num_blocks;
|
||||
// auto kernel =
|
||||
// cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
// if (args.row_size <= 64) {
|
||||
// if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
// (args.non_row_reductions <= 8)) {
|
||||
// block_dims.x = std::min(out_dims.x, 1024u);
|
||||
// num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
// num_blocks.y = out_dims.y;
|
||||
// } else {
|
||||
// block_dims.x = WARP_SIZE;
|
||||
// num_blocks.y = out_dims.x;
|
||||
// num_blocks.z = out_dims.y;
|
||||
// kernel =
|
||||
// cu::row_reduce_small_warp<InType, OutType, OP, NDIM,
|
||||
// N_READS>;
|
||||
// }
|
||||
// } else {
|
||||
// size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
// num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
// MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
// num_blocks.y = out_dims.x;
|
||||
// num_blocks.z = out_dims.y;
|
||||
// block_dims.x = BLOCK_DIM_X;
|
||||
// kernel = cu::row_reduce_looped<
|
||||
// InType,
|
||||
// OutType,
|
||||
// OP,
|
||||
// NDIM,
|
||||
// BLOCK_DIM_X,
|
||||
// N_READS>;
|
||||
// });
|
||||
// }
|
||||
// kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
// in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,84 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user