mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Compare commits
5 Commits
55d4a0be65
...
f07eb684a6
Author | SHA1 | Date | |
---|---|---|---|
![]() |
f07eb684a6 | ||
![]() |
850ad01914 | ||
![]() |
4d95cb24b4 | ||
![]() |
aa07429bad | ||
![]() |
918761a25a |
@ -30,9 +30,11 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
@ -43,8 +45,8 @@ target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
|||||||
file(
|
file(
|
||||||
GLOB MLX_JIT_SOURCES
|
GLOB MLX_JIT_SOURCES
|
||||||
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.h"
|
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.cuh")
|
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
|
||||||
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT gen/cuda_jit_sources.h
|
OUTPUT gen/cuda_jit_sources.h
|
||||||
|
12
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
12
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
struct Select {
|
||||||
|
template <typename T>
|
||||||
|
__device__ T operator()(bool condition, T x, T y) {
|
||||||
|
return condition ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -162,6 +162,27 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int NDIM, typename IdxT = int64_t>
|
||||||
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||||
|
IdxT elem,
|
||||||
|
const int* shape,
|
||||||
|
const int64_t* a_strides,
|
||||||
|
const int64_t* b_strides,
|
||||||
|
const int64_t* c_strides) {
|
||||||
|
IdxT a_loc = 0;
|
||||||
|
IdxT b_loc = 0;
|
||||||
|
IdxT c_loc = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int dim_idx = elem % shape[i];
|
||||||
|
a_loc += dim_idx * a_strides[i];
|
||||||
|
b_loc += dim_idx * b_strides[i];
|
||||||
|
c_loc += dim_idx * c_strides[i];
|
||||||
|
elem /= shape[i];
|
||||||
|
}
|
||||||
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
}
|
||||||
|
|
||||||
// Optimized version when ndim is larger than 4.
|
// Optimized version when ndim is larger than 4.
|
||||||
template <typename IdxT = int64_t>
|
template <typename IdxT = int64_t>
|
||||||
inline __host__ __device__ IdxT
|
inline __host__ __device__ IdxT
|
||||||
@ -191,6 +212,26 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename IdxT = int64_t>
|
||||||
|
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||||
|
IdxT elem,
|
||||||
|
const int* shape,
|
||||||
|
const int64_t* a_strides,
|
||||||
|
const int64_t* b_strides,
|
||||||
|
const int64_t* c_strides,
|
||||||
|
int ndim) {
|
||||||
|
auto [a_loc, b_loc, c_loc] =
|
||||||
|
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
||||||
|
for (int i = ndim - 1; i >= 3; --i) {
|
||||||
|
int dim_idx = elem % shape[i];
|
||||||
|
a_loc += dim_idx * a_strides[i];
|
||||||
|
b_loc += dim_idx * b_strides[i];
|
||||||
|
c_loc += dim_idx * c_strides[i];
|
||||||
|
elem /= shape[i];
|
||||||
|
}
|
||||||
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Elem to loc in a loop utils
|
// Elem to loc in a loop utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
array o = set_output(inputs[0]);
|
const array x = set_output(inputs[0]);
|
||||||
const array& x = o.data_shared_ptr() ? o : out;
|
|
||||||
const array& w = inputs[1];
|
const array& w = inputs[1];
|
||||||
const array& b = inputs[2];
|
const array& b = inputs[2];
|
||||||
|
|
||||||
|
@ -91,7 +91,6 @@ NO_GPU(QuantizedMatmul)
|
|||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
NO_GPU(Scatter)
|
NO_GPU(Scatter)
|
||||||
NO_GPU(ScatterAxis)
|
NO_GPU(ScatterAxis)
|
||||||
NO_GPU(Select)
|
|
||||||
NO_GPU_MULTI(SVD)
|
NO_GPU_MULTI(SVD)
|
||||||
NO_GPU(Inverse)
|
NO_GPU(Inverse)
|
||||||
NO_GPU(Cholesky)
|
NO_GPU(Cholesky)
|
||||||
@ -99,8 +98,6 @@ NO_GPU_MULTI(Eig)
|
|||||||
NO_GPU_MULTI(Eigh)
|
NO_GPU_MULTI(Eigh)
|
||||||
|
|
||||||
namespace fast {
|
namespace fast {
|
||||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
|
||||||
NO_GPU_MULTI(RMSNormVJP)
|
|
||||||
NO_GPU_USE_FALLBACK(RoPE)
|
NO_GPU_USE_FALLBACK(RoPE)
|
||||||
NO_GPU(ScaledDotProductAttention)
|
NO_GPU(ScaledDotProductAttention)
|
||||||
NO_GPU_MULTI(AffineQuantize)
|
NO_GPU_MULTI(AffineQuantize)
|
||||||
|
343
mlx/backend/cuda/rms_norm.cu
Normal file
343
mlx/backend/cuda/rms_norm.cu
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <cub/block/block_load.cuh>
|
||||||
|
#include <cub/block/block_reduce.cuh>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
inline __device__ float2 plus_f2(const float2& a, const float2& b) {
|
||||||
|
return {a.x + b.x, a.y + b.y};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||||
|
template <typename T, int BLOCK_DIM>
|
||||||
|
struct BlockBroadcastReduce {
|
||||||
|
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||||
|
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||||
|
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||||
|
|
||||||
|
cg::thread_block& block;
|
||||||
|
TempStorage& temp;
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
T x = cg::reduce(warp, input, op);
|
||||||
|
if (warp.thread_rank() == 0) {
|
||||||
|
temp[warp.meta_group_rank()] = x;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||||
|
: init_value;
|
||||||
|
return cg::reduce(warp, x, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ T Sum(const T& input) {
|
||||||
|
return Reduce(input, cg::plus<T>{}, T{});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void rms_norm(
|
||||||
|
const T* x,
|
||||||
|
const T* w,
|
||||||
|
T* out,
|
||||||
|
float eps,
|
||||||
|
int32_t axis_size,
|
||||||
|
int64_t w_stride) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||||
|
__shared__ typename BlockReduceT::TempStorage temp;
|
||||||
|
|
||||||
|
x += grid.block_rank() * axis_size;
|
||||||
|
out += grid.block_rank() * axis_size;
|
||||||
|
|
||||||
|
// Normalizer.
|
||||||
|
float normalizer = 0;
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
float t = static_cast<float>(xn[i]);
|
||||||
|
normalizer += t * t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||||
|
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||||
|
|
||||||
|
// Outputs.
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS];
|
||||||
|
T wn[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||||
|
xn[i] = wn[i] * static_cast<T>(norm);
|
||||||
|
}
|
||||||
|
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||||
|
__global__ void rms_norm_vjp(
|
||||||
|
const T* x,
|
||||||
|
const T* w,
|
||||||
|
const T* g,
|
||||||
|
T* gx,
|
||||||
|
T* gw,
|
||||||
|
float eps,
|
||||||
|
int32_t axis_size,
|
||||||
|
int64_t w_stride) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||||
|
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||||
|
__shared__ union {
|
||||||
|
typename BlockReduceF::TempStorage f;
|
||||||
|
typename BlockReduceF2::TempStorage f2;
|
||||||
|
} temp;
|
||||||
|
|
||||||
|
x += grid.block_rank() * axis_size;
|
||||||
|
g += grid.block_rank() * axis_size;
|
||||||
|
gx += grid.block_rank() * axis_size;
|
||||||
|
gw += grid.block_rank() * axis_size;
|
||||||
|
|
||||||
|
// Normalizer.
|
||||||
|
float2 factors = {};
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
T xn[N_READS];
|
||||||
|
T wn[N_READS] = {};
|
||||||
|
T gn[N_READS] = {};
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
|
||||||
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
float t = static_cast<float>(xn[i]);
|
||||||
|
float wi = wn[i];
|
||||||
|
float gi = gn[i];
|
||||||
|
float wg = wi * gi;
|
||||||
|
factors = plus_f2(factors, {wg * t, t * t});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
||||||
|
float meangwx = factors.x / axis_size;
|
||||||
|
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||||
|
float normalizer3 = normalizer * normalizer * normalizer;
|
||||||
|
|
||||||
|
// Outputs.
|
||||||
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
|
T xn[N_READS];
|
||||||
|
T wn[N_READS];
|
||||||
|
T gn[N_READS];
|
||||||
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
|
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
float xi = xn[i];
|
||||||
|
float wi = wn[i];
|
||||||
|
float gi = gn[i];
|
||||||
|
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
|
||||||
|
if constexpr (HAS_W) {
|
||||||
|
wn[i] = static_cast<T>(gi * xi * normalizer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||||
|
if constexpr (HAS_W) {
|
||||||
|
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace fast {
|
||||||
|
|
||||||
|
bool RMSNorm::use_fallback(Stream s) {
|
||||||
|
return s.device == Device::cpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||||
|
void RMSNorm::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("RMSNorm::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& out = outputs[0];
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous.
|
||||||
|
auto set_output = [&s, &out](const array& x) {
|
||||||
|
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||||
|
if (no_copy && x.ndim() > 1) {
|
||||||
|
auto s = x.strides()[x.ndim() - 2];
|
||||||
|
no_copy &= (s == 0 || s == x.shape().back());
|
||||||
|
}
|
||||||
|
if (no_copy) {
|
||||||
|
if (x.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(x);
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc(x.data_size() * x.itemsize()),
|
||||||
|
x.data_size(),
|
||||||
|
x.strides(),
|
||||||
|
x.flags());
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
out.copy_shared_buffer(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const array x = set_output(inputs[0]);
|
||||||
|
const array& w = inputs[1];
|
||||||
|
|
||||||
|
int32_t axis_size = x.shape().back();
|
||||||
|
int32_t n_rows = x.data_size() / axis_size;
|
||||||
|
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(x);
|
||||||
|
encoder.set_input_array(w);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
|
||||||
|
using DataType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr uint32_t N_READS = 4;
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
|
||||||
|
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||||
|
x.data<DataType>(),
|
||||||
|
w.data<DataType>(),
|
||||||
|
out.data<DataType>(),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
w_stride);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void RMSNormVJP::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
// Ensure row contiguity. We could relax this step by checking that the array
|
||||||
|
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||||
|
// same as the cotangent strides but for now this is simpler.
|
||||||
|
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||||
|
if (x.flags().row_contiguous) {
|
||||||
|
return {x, false};
|
||||||
|
}
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
|
return {x_copy, true};
|
||||||
|
};
|
||||||
|
bool donate_x = inputs[0].is_donatable();
|
||||||
|
bool donate_g = inputs[2].is_donatable();
|
||||||
|
auto [x, copied] = check_input(inputs[0]);
|
||||||
|
donate_x |= copied;
|
||||||
|
const array& w = inputs[1];
|
||||||
|
auto [g, g_copied] = check_input(inputs[2]);
|
||||||
|
donate_g |= g_copied;
|
||||||
|
array& gx = outputs[0];
|
||||||
|
array& gw = outputs[1];
|
||||||
|
|
||||||
|
// Check whether we had a weight.
|
||||||
|
bool has_w = w.ndim() != 0;
|
||||||
|
|
||||||
|
// Allocate space for the outputs.
|
||||||
|
bool g_in_gx = false;
|
||||||
|
if (donate_x) {
|
||||||
|
gx.copy_shared_buffer(x);
|
||||||
|
} else if (donate_g) {
|
||||||
|
gx.copy_shared_buffer(g);
|
||||||
|
g_in_gx = true;
|
||||||
|
} else {
|
||||||
|
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||||
|
}
|
||||||
|
if (g_copied && !g_in_gx) {
|
||||||
|
encoder.add_temporary(g);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t axis_size = x.shape().back();
|
||||||
|
int32_t n_rows = x.data_size() / axis_size;
|
||||||
|
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||||
|
|
||||||
|
// Allocate a temporary to store the gradients for w and allocate the output
|
||||||
|
// gradient accumulators.
|
||||||
|
array gw_temp =
|
||||||
|
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||||
|
if (has_w) {
|
||||||
|
if (!g_in_gx && donate_g) {
|
||||||
|
gw_temp.copy_shared_buffer(g);
|
||||||
|
} else {
|
||||||
|
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||||
|
encoder.add_temporary(gw_temp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||||
|
|
||||||
|
encoder.set_input_array(x);
|
||||||
|
encoder.set_input_array(w);
|
||||||
|
encoder.set_input_array(g);
|
||||||
|
encoder.set_output_array(gx);
|
||||||
|
encoder.set_output_array(gw_temp);
|
||||||
|
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
|
||||||
|
using DataType = cuda_type_t<CTYPE>;
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||||
|
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||||
|
auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||||
|
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||||
|
x.data<DataType>(),
|
||||||
|
w.data<DataType>(),
|
||||||
|
g.data<DataType>(),
|
||||||
|
gx.data<DataType>(),
|
||||||
|
gw_temp.data<DataType>(),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
w_stride);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (has_w) {
|
||||||
|
ReductionPlan plan(
|
||||||
|
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||||
|
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fast
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
177
mlx/backend/cuda/ternary.cu
Normal file
177
mlx/backend/cuda/ternary.cu
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
#include "mlx/backend/common/ternary.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename Op, typename T, typename IdxT>
|
||||||
|
__global__ void
|
||||||
|
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
out[index] = Op{}(a[index], b[index], c[index]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename T, typename IdxT, int NDIM>
|
||||||
|
__global__ void ternary_g_nd(
|
||||||
|
const bool* a,
|
||||||
|
const T* b,
|
||||||
|
const T* c,
|
||||||
|
T* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index,
|
||||||
|
shape.data(),
|
||||||
|
a_strides.data(),
|
||||||
|
b_strides.data(),
|
||||||
|
c_strides.data());
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename T, typename IdxT>
|
||||||
|
__global__ void ternary_g(
|
||||||
|
const bool* a,
|
||||||
|
const T* b,
|
||||||
|
const T* c,
|
||||||
|
T* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ Shape shape,
|
||||||
|
const __grid_constant__ Strides a_strides,
|
||||||
|
const __grid_constant__ Strides b_strides,
|
||||||
|
const __grid_constant__ Strides c_strides,
|
||||||
|
int ndim) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
|
||||||
|
index,
|
||||||
|
shape.data(),
|
||||||
|
a_strides.data(),
|
||||||
|
b_strides.data(),
|
||||||
|
c_strides.data(),
|
||||||
|
ndim);
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void ternary_op_gpu_inplace(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const Stream& s) {
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
const auto& b = inputs[1];
|
||||||
|
const auto& c = inputs[2];
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
|
||||||
|
using DType = cuda_type_t<CTYPE>;
|
||||||
|
|
||||||
|
auto topt = get_ternary_op_type(a, b, c);
|
||||||
|
if (topt == TernaryOpType::General) {
|
||||||
|
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||||
|
auto& a_strides = strides[0];
|
||||||
|
auto& b_strides = strides[1];
|
||||||
|
auto& c_strides = strides[2];
|
||||||
|
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||||
|
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||||
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
|
int ndim = shape.size();
|
||||||
|
if (ndim <= 3) {
|
||||||
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
|
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, out, large);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<bool>(),
|
||||||
|
b.data<DType>(),
|
||||||
|
c.data<DType>(),
|
||||||
|
out.data<DType>(),
|
||||||
|
out.data_size(),
|
||||||
|
const_param<NDIM>(shape),
|
||||||
|
const_param<NDIM>(a_strides),
|
||||||
|
const_param<NDIM>(b_strides),
|
||||||
|
const_param<NDIM>(c_strides));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
auto kernel = cu::ternary_g<Op, DType, IdxT>;
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<bool>(),
|
||||||
|
b.data<DType>(),
|
||||||
|
c.data<DType>(),
|
||||||
|
out.data<DType>(),
|
||||||
|
out.data_size(),
|
||||||
|
const_param(shape),
|
||||||
|
const_param(a_strides),
|
||||||
|
const_param(b_strides),
|
||||||
|
const_param(c_strides),
|
||||||
|
ndim);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||||
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
|
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
||||||
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
a.data<bool>(),
|
||||||
|
b.data<DType>(),
|
||||||
|
c.data<DType>(),
|
||||||
|
out.data<DType>(),
|
||||||
|
out.data_size());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void ternary_op_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto& c = inputs[2];
|
||||||
|
auto topt = get_ternary_op_type(a, b, c);
|
||||||
|
set_ternary_op_output_data(a, b, c, out, topt);
|
||||||
|
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("select::eval_gpu");
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user