mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 09:18:12 +08:00
@@ -101,10 +101,12 @@ constexpr bool supports_binary_op() {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp> || std::is_same_v<Op, ArcTan2>) {
|
||||
if (std::is_same_v<Op, LogAddExp>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
@@ -150,10 +152,10 @@ void binary_op_gpu_inplace(
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
bool large = a.data_size() > UINT32_MAX ||
|
||||
b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = a.data_size() > INT32_MAX ||
|
||||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
|
@@ -130,11 +130,13 @@ struct FusedKernelBuilder {
|
||||
|
||||
constexpr const char* g_jit_includes = R"(
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
)";
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
|
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
@@ -122,6 +124,26 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
};
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
return maxval;
|
||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
||||
cuComplex dexp{
|
||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@@ -1,4 +1,5 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
|
@@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = elem_to_loc_nd<3>(elem, shape, strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
@@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
@@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides,
|
||||
int ndim) {
|
||||
auto [a_loc, b_loc, c_loc] =
|
||||
elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
|
||||
for (int i = ndim - 1; i >= 3; --i) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * a_strides[i];
|
||||
b_loc += dim_idx * b_strides[i];
|
||||
@@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ cuComplex log1p(cuComplex in) {
|
||||
float x = cuCrealf(in);
|
||||
float y = cuCimagf(in);
|
||||
float zabs = sqrt(x * x + y * y);
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(src.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t slice_size = std::accumulate(
|
||||
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||
@@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
@@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(out.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(out.size());
|
||||
mod.append_arg<int32_t>(out.size());
|
||||
}
|
||||
mod.append_ndim_arg(src.shape());
|
||||
mod.append_ndim_arg(src.strides());
|
||||
@@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
dtype_to_cuda_type(idx_dtype),
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||
|
||||
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||
bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) ||
|
||||
(upd.size() > INT32_MAX) || (out.size() > INT32_MAX);
|
||||
|
||||
uint32_t upd_post_idx_size = std::accumulate(
|
||||
int32_t upd_post_idx_size = std::accumulate(
|
||||
upd.shape().begin() + idx_ndim,
|
||||
upd.shape().end(),
|
||||
1,
|
||||
std::multiplies<uint32_t>());
|
||||
std::multiplies<int32_t>());
|
||||
|
||||
const char* op = g_scatter_ops[reduce_type_];
|
||||
std::string module_name = fmt::format(
|
||||
@@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op,
|
||||
nidx,
|
||||
ndim,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
@@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd.size());
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd.size());
|
||||
mod.append_arg<int32_t>(upd.size());
|
||||
}
|
||||
mod.append_ndim_arg(upd.shape());
|
||||
mod.append_ndim_arg(upd.strides());
|
||||
@@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (large) {
|
||||
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(upd_post_idx_size);
|
||||
mod.append_arg<int32_t>(upd_post_idx_size);
|
||||
}
|
||||
mod.append_ndim_arg(out.shape());
|
||||
mod.append_ndim_arg(out.strides());
|
||||
@@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string module_name = fmt::format(
|
||||
"gather_axis_{}_{}",
|
||||
@@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(src.strides(), axis_));
|
||||
@@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
src.ndim() - 1,
|
||||
src.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
@@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||
std::string module_name = fmt::format(
|
||||
@@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
ndim,
|
||||
contiguous & 1 ? true : false,
|
||||
contiguous & 2 ? true : false,
|
||||
large ? "int64_t" : "uint32_t"));
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
mod.append_arg<int64_t>(idx_size_axis);
|
||||
mod.append_arg<int64_t>(idx_size_post);
|
||||
} else {
|
||||
mod.append_arg<uint32_t>(idx_size_pre);
|
||||
mod.append_arg<uint32_t>(idx_size_axis);
|
||||
mod.append_arg<uint32_t>(idx_size_post);
|
||||
mod.append_arg<int32_t>(idx_size_pre);
|
||||
mod.append_arg<int32_t>(idx_size_axis);
|
||||
mod.append_arg<int32_t>(idx_size_post);
|
||||
}
|
||||
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||
@@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx.ndim() - 1,
|
||||
upd.flags().row_contiguous,
|
||||
idx.flags().row_contiguous,
|
||||
large ? "int64_t" : "uint32_t");
|
||||
large ? "int64_t" : "int32_t");
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
|
@@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
||||
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||
MLX_SWITCH_BOOL(large, LARGE, {
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
|
@@ -27,13 +27,12 @@ constexpr bool supports_unary_op() {
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
||||
std::is_same_v<Op, Rsqrt>) {
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10>) {
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
|
@@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
if (dtype == bfloat16) {
|
||||
return "__nv_bfloat16";
|
||||
}
|
||||
if (dtype == complex64) {
|
||||
return "cuComplex";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (dtype == DTYPE) { \
|
||||
return #CPP_TYPE; \
|
||||
|
Reference in New Issue
Block a user