mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 12:31:13 +08:00
[CUDA] Implement Scan kernel (#2347)
* Contiguous scan * Strided scan * Enable tests * Fix failing logaddexp test * Use cexpf in Metal
This commit is contained in:
parent
b6eec20260
commit
8347575ba1
@ -35,6 +35,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# Enable calling host constexpr functions from device. This is needed because
|
||||
# the constexpr version of isnan is host only.
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
|
||||
|
||||
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
||||
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
||||
# true but the warning wouldn't be suppressed.
|
||||
|
@ -1,10 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/array>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
@ -114,36 +111,38 @@ struct LessEqual {
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if (isnan(x) || isnan(y)) {
|
||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
|
||||
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
|
||||
auto min_real = cuCrealf(min);
|
||||
auto max_real = cuCrealf(max);
|
||||
if (!isfinite(min_real) && (min_real == max_real)) {
|
||||
if (min_real < 0) {
|
||||
return min;
|
||||
} else {
|
||||
return Log{}(Exp{}(min) + Exp{}(max));
|
||||
}
|
||||
} else {
|
||||
return Log1p{}(Exp{}(min - max)) + max;
|
||||
}
|
||||
} else {
|
||||
if (isnan(x) || isnan(y)) {
|
||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
T maxval = max(x, y);
|
||||
T minval = min(x, y);
|
||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
}
|
||||
T maxval = max(x, y);
|
||||
T minval = min(x, y);
|
||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
};
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
||||
auto maxval = x > y ? x : y;
|
||||
auto minval = x < y ? x : y;
|
||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
||||
return maxval;
|
||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
||||
cuComplex dexp{
|
||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
138
mlx/backend/cuda/device/cexpf.cuh
Normal file
138
mlx/backend/cuda/device/cexpf.cuh
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2008-2013 NVIDIA Corporation
|
||||
// Copyright © 2013 Filipe RNC Maia
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
||||
|
||||
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
||||
// can not be used in JIT.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda/std/cstdint>
|
||||
|
||||
namespace mlx::core::cu::detail {
|
||||
|
||||
using ieee_float_shape_type = union {
|
||||
float value;
|
||||
uint32_t word;
|
||||
};
|
||||
|
||||
inline __device__ void get_float_word(uint32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
inline __device__ void get_float_word(int32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
inline __device__ void set_float_word(float& d, uint32_t i) {
|
||||
ieee_float_shape_type sf_u;
|
||||
sf_u.word = (i);
|
||||
(d) = sf_u.value;
|
||||
}
|
||||
|
||||
inline __device__ float frexp_expf(float x, int* expt) {
|
||||
const uint32_t k = 235;
|
||||
const float kln2 = 162.88958740F;
|
||||
|
||||
float exp_x;
|
||||
uint32_t hx;
|
||||
|
||||
exp_x = expf(x - kln2);
|
||||
get_float_word(hx, exp_x);
|
||||
*expt = (hx >> 23) - (0x7f + 127) + k;
|
||||
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
|
||||
return exp_x;
|
||||
}
|
||||
|
||||
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
|
||||
float x, y, exp_x, scale1, scale2;
|
||||
int ex_expt, half_expt;
|
||||
|
||||
x = cuCrealf(z);
|
||||
y = cuCimagf(z);
|
||||
exp_x = frexp_expf(x, &ex_expt);
|
||||
expt += ex_expt;
|
||||
|
||||
half_expt = expt / 2;
|
||||
set_float_word(scale1, (0x7f + half_expt) << 23);
|
||||
half_expt = expt - half_expt;
|
||||
set_float_word(scale2, (0x7f + half_expt) << 23);
|
||||
|
||||
return cuComplex{
|
||||
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
|
||||
}
|
||||
|
||||
inline __device__ cuComplex cexpf(const cuComplex& z) {
|
||||
float x, y, exp_x;
|
||||
uint32_t hx, hy;
|
||||
|
||||
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
|
||||
|
||||
x = cuCrealf(z);
|
||||
y = cuCimagf(z);
|
||||
|
||||
get_float_word(hy, y);
|
||||
hy &= 0x7fffffff;
|
||||
|
||||
/* cexp(x + I 0) = exp(x) + I 0 */
|
||||
if (hy == 0) {
|
||||
return cuComplex{expf(x), y};
|
||||
}
|
||||
get_float_word(hx, x);
|
||||
/* cexp(0 + I y) = cos(y) + I sin(y) */
|
||||
if ((hx & 0x7fffffff) == 0) {
|
||||
return cuComplex{cosf(y), sinf(y)};
|
||||
}
|
||||
if (hy >= 0x7f800000) {
|
||||
if ((hx & 0x7fffffff) != 0x7f800000) {
|
||||
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
|
||||
return cuComplex{y - y, y - y};
|
||||
} else if (hx & 0x80000000) {
|
||||
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
|
||||
return cuComplex{0.0, 0.0};
|
||||
} else {
|
||||
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
|
||||
return cuComplex{x, y - y};
|
||||
}
|
||||
}
|
||||
|
||||
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
|
||||
/*
|
||||
* x is between 88.7 and 192, so we must scale to avoid
|
||||
* overflow in expf(x).
|
||||
*/
|
||||
return ldexp_cexpf(z, 0);
|
||||
} else {
|
||||
/*
|
||||
* Cases covered here:
|
||||
* - x < exp_ovfl and exp(x) won't overflow (common case)
|
||||
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
|
||||
* - x = +-Inf (generated by exp())
|
||||
* - x = NaN (spurious inexact exception from y)
|
||||
*/
|
||||
exp_x = expf(x);
|
||||
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu::detail
|
@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/cexpf.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
@ -150,8 +152,7 @@ struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto m = exp(cuCrealf(x));
|
||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
||||
return detail::cexpf(x);
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
@ -228,8 +229,25 @@ struct Log10 {
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
__device__ T operator()(T z) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float x = cuCrealf(z);
|
||||
float y = cuCimagf(z);
|
||||
float zabs = cuCrealf(Abs{}(z));
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
float z0 = hypotf(x + 1, y);
|
||||
return {logf(z0), theta};
|
||||
}
|
||||
} else {
|
||||
return log1p(z);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -387,19 +405,19 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto ix = i * x;
|
||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
||||
|
@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
}
|
||||
};
|
||||
|
||||
inline __device__ cuComplex log1p(cuComplex in) {
|
||||
float x = cuCrealf(in);
|
||||
float y = cuCimagf(in);
|
||||
float zabs = sqrt(x * x + y * y);
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = {
|
||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||
INCLUDE_PREFIX "binary_ops.cuh",
|
||||
INCLUDE_PREFIX "cast_op.cuh",
|
||||
INCLUDE_PREFIX "cexpf.cuh",
|
||||
INCLUDE_PREFIX "config.h",
|
||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||
INCLUDE_PREFIX "fp16_math.cuh",
|
||||
@ -177,6 +178,7 @@ constexpr const char* g_headers[] = {
|
||||
jit_source_atomic_ops,
|
||||
jit_source_binary_ops,
|
||||
jit_source_cast_op,
|
||||
jit_source_cexpf,
|
||||
jit_source_config,
|
||||
jit_source_cucomplex_math,
|
||||
jit_source_fp16_math,
|
||||
|
@ -82,7 +82,6 @@ NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(SegmentedMM)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Inverse)
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
467
mlx/backend/cuda/scan.cu
Normal file
467
mlx/backend/cuda/scan.cu
Normal file
@ -0,0 +1,467 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/scan.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename T>
|
||||
struct ScanResult {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ScanResult<Sum, bool> {
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<LogAddExp, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
template <bool reverse, typename T, typename U, int N_READS>
|
||||
inline __device__ void
|
||||
load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
|
||||
int remaining = size - index * N_READS;
|
||||
if constexpr (reverse) {
|
||||
in += remaining - N_READS;
|
||||
if (remaining < N_READS) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
values[N_READS - i - 1] =
|
||||
(N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
values[N_READS - i - 1] = cast_to<U>(in[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
in += index * N_READS;
|
||||
if (remaining < N_READS) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
values[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
values[i] = cast_to<U>(in[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool reverse, int offset, typename T, int N_READS>
|
||||
inline __device__ void
|
||||
store_values(int index, T* out, T (&values)[N_READS], int size) {
|
||||
int start = index * N_READS + offset;
|
||||
int remaining = size - start;
|
||||
if constexpr (reverse) {
|
||||
out += remaining - N_READS;
|
||||
if (remaining < N_READS) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
if (N_READS - i - 1 < remaining) {
|
||||
out[i] = values[N_READS - i - 1];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out[i] = values[N_READS - i - 1];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
out += start;
|
||||
if (remaining < N_READS) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
if (i < remaining) {
|
||||
out[i] = values[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out[i] = values[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS,
|
||||
bool inclusive,
|
||||
bool reverse>
|
||||
__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
__shared__ U warp_sums[WARP_SIZE];
|
||||
|
||||
Op op;
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
U prefix = init;
|
||||
|
||||
// Scan per block.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
|
||||
int32_t index = r * block.size() + block.thread_rank();
|
||||
U values[N_READS];
|
||||
load_values<reverse>(index, in, values, axis_size, init);
|
||||
|
||||
// Compute an inclusive scan per thread.
|
||||
for (int i = 1; i < N_READS; ++i) {
|
||||
values[i] = op(values[i], values[i - 1]);
|
||||
}
|
||||
|
||||
// Compute exclusive scan of thread sums.
|
||||
U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
prev_thread_sum = init;
|
||||
}
|
||||
|
||||
// Write wrap's sum to shared memory.
|
||||
if (warp.thread_rank() == WARP_SIZE - 1) {
|
||||
warp_sums[warp.meta_group_rank()] =
|
||||
op(prev_thread_sum, values[N_READS - 1]);
|
||||
}
|
||||
block.sync();
|
||||
|
||||
// Compute exclusive scan of warp sums.
|
||||
if (warp.meta_group_rank() == 0) {
|
||||
U prev_warp_sum =
|
||||
cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
prev_warp_sum = init;
|
||||
}
|
||||
warp_sums[warp.thread_rank()] = prev_warp_sum;
|
||||
}
|
||||
block.sync();
|
||||
|
||||
// Compute the output.
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
values[i] = op(values[i], prefix);
|
||||
values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
|
||||
values[i] = op(values[i], prev_thread_sum);
|
||||
}
|
||||
|
||||
// Write the values.
|
||||
if (inclusive) {
|
||||
store_values<reverse, 0>(index, out, values, axis_size);
|
||||
} else {
|
||||
store_values<reverse, 1>(index, out, values, axis_size);
|
||||
if (reverse) {
|
||||
if (block.thread_rank() == 0 && index == 0) {
|
||||
out[axis_size - 1] = init;
|
||||
}
|
||||
} else {
|
||||
if (block.thread_rank() == 0 && index == 0) {
|
||||
out[0] = init;
|
||||
}
|
||||
}
|
||||
}
|
||||
block.sync();
|
||||
|
||||
// Share the prefix.
|
||||
if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
|
||||
(warp.thread_rank() == WARP_SIZE - 1)) {
|
||||
warp_sums[0] = values[N_READS - 1];
|
||||
}
|
||||
block.sync();
|
||||
prefix = warp_sums[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS,
|
||||
int BM,
|
||||
int BN,
|
||||
bool inclusive,
|
||||
bool reverse>
|
||||
__global__ void strided_scan(
|
||||
const T* in,
|
||||
U* out,
|
||||
int32_t axis_size,
|
||||
int64_t stride,
|
||||
int64_t stride_blocks) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
constexpr int n_scans = BN / n_warps;
|
||||
|
||||
__shared__ U read_buffer[BM * BN_pad];
|
||||
|
||||
Op op;
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
U values[n_scans];
|
||||
U prefix[n_scans];
|
||||
for (int i = 0; i < n_scans; ++i) {
|
||||
prefix[i] = init;
|
||||
}
|
||||
|
||||
// Compute offsets.
|
||||
int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
|
||||
int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
|
||||
uint read_offset_y = (block.thread_rank() * N_READS) / BN;
|
||||
uint read_offset_x = (block.thread_rank() * N_READS) % BN;
|
||||
uint scan_offset_y = warp.thread_rank();
|
||||
uint scan_offset_x = warp.meta_group_rank() * n_scans;
|
||||
|
||||
uint stride_limit = stride - global_index_x;
|
||||
in += offset + global_index_x + read_offset_x;
|
||||
out += offset + global_index_x + read_offset_x;
|
||||
U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
|
||||
U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;
|
||||
|
||||
for (uint j = 0; j < axis_size; j += BM) {
|
||||
// Calculate the indices for the current thread.
|
||||
uint index_y = j + read_offset_y;
|
||||
uint check_index_y = index_y;
|
||||
if (reverse) {
|
||||
index_y = axis_size - 1 - index_y;
|
||||
}
|
||||
|
||||
// Read in SM.
|
||||
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
read_into[i] = in[index_y * stride + i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
|
||||
read_into[i] = in[index_y * stride + i];
|
||||
} else {
|
||||
read_into[i] = init;
|
||||
}
|
||||
}
|
||||
}
|
||||
block.sync();
|
||||
|
||||
// Read strided into registers.
|
||||
for (int i = 0; i < n_scans; ++i) {
|
||||
values[i] = read_from[i];
|
||||
}
|
||||
|
||||
// Perform the scan.
|
||||
for (int i = 0; i < n_scans; ++i) {
|
||||
values[i] = cg::inclusive_scan(warp, values[i], op);
|
||||
values[i] = op(values[i], prefix[i]);
|
||||
prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
|
||||
}
|
||||
|
||||
// Write to SM.
|
||||
for (int i = 0; i < n_scans; ++i) {
|
||||
read_from[i] = values[i];
|
||||
}
|
||||
block.sync();
|
||||
|
||||
// Write to device memory.
|
||||
if (!inclusive) {
|
||||
if (check_index_y == 0) {
|
||||
if ((read_offset_x + N_READS) < stride_limit) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out[index_y * stride + i] = init;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
if ((read_offset_x + i) < stride_limit) {
|
||||
out[index_y * stride + i] = init;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (reverse) {
|
||||
index_y -= 1;
|
||||
check_index_y += 1;
|
||||
} else {
|
||||
index_y += 1;
|
||||
check_index_y += 1;
|
||||
}
|
||||
}
|
||||
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out[index_y * stride + i] = read_into[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
|
||||
out[index_y * stride + i] = read_into[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename F>
|
||||
void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {
|
||||
if (scan_op == Scan::ReduceType::Max) {
|
||||
f(type_identity<cu::Max>{});
|
||||
} else if (scan_op == Scan::ReduceType::Min) {
|
||||
f(type_identity<cu::Min>{});
|
||||
} else if (scan_op == Scan::ReduceType::Sum) {
|
||||
f(type_identity<cu::Sum>{});
|
||||
} else if (scan_op == Scan::ReduceType::Prod) {
|
||||
f(type_identity<cu::Prod>{});
|
||||
} else if (scan_op == Scan::ReduceType::LogAddExp) {
|
||||
f(type_identity<cu::LogAddExp>{});
|
||||
} else {
|
||||
throw std::invalid_argument("Unknown reduce type.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
const char* op_to_string() {
|
||||
if (cuda::std::is_same_v<Op, cu::Max>) {
|
||||
return "Max";
|
||||
} else if (cuda::std::is_same_v<Op, cu::Min>) {
|
||||
return "Min";
|
||||
} else if (cuda::std::is_same_v<Op, cu::Sum>) {
|
||||
return "Sum";
|
||||
} else if (cuda::std::is_same_v<Op, cu::Prod>) {
|
||||
return "Prod";
|
||||
} else if (cuda::std::is_same_v<Op, cu::LogAddExp>) {
|
||||
return "LogAddExp";
|
||||
} else {
|
||||
throw std::invalid_argument("Unknown op.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename T>
|
||||
constexpr bool supports_scan_op() {
|
||||
if constexpr (cuda::std::is_same_v<Op, LogAddExp>) {
|
||||
return is_inexact_v<T>;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Scan::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto in = inputs[0];
|
||||
auto& s = stream();
|
||||
|
||||
if (in.flags().contiguous && in.strides()[axis_] != 0) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
in = std::move(arr_copy);
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
int32_t axis_size = in.shape(axis_);
|
||||
bool contiguous = in.strides()[axis_] == 1;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
|
||||
using Op = MLX_GET_TYPE(scan_op_tag);
|
||||
if constexpr (supports_scan_op<Op, T>) {
|
||||
using U = typename cu::ScanResult<Op, T>::type;
|
||||
dispatch_bool(inclusive_, [&](auto inclusive) {
|
||||
dispatch_bool(reverse_, [&](auto reverse) {
|
||||
if (contiguous) {
|
||||
auto kernel = cu::contiguous_scan<
|
||||
T,
|
||||
U,
|
||||
Op,
|
||||
N_READS,
|
||||
inclusive.value,
|
||||
reverse.value>;
|
||||
int block_dim = cuda::ceil_div(axis_size, N_READS);
|
||||
block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
|
||||
block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
in.data_size() / axis_size,
|
||||
block_dim,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
axis_size);
|
||||
} else {
|
||||
constexpr int BM = WARP_SIZE;
|
||||
constexpr int BN = WARP_SIZE;
|
||||
auto kernel = cu::strided_scan<
|
||||
T,
|
||||
U,
|
||||
Op,
|
||||
N_READS,
|
||||
BM,
|
||||
BN,
|
||||
inclusive.value,
|
||||
reverse.value>;
|
||||
int64_t stride = in.strides()[axis_];
|
||||
int64_t stride_blocks = cuda::ceil_div(stride, BN);
|
||||
dim3 num_blocks = get_2d_grid_dims(
|
||||
in.shape(), in.strides(), axis_size * stride);
|
||||
if (num_blocks.x * stride_blocks <= UINT32_MAX) {
|
||||
num_blocks.x *= stride_blocks;
|
||||
} else {
|
||||
num_blocks.y *= stride_blocks;
|
||||
}
|
||||
int block_dim = (BN / N_READS) * WARP_SIZE;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dim,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
axis_size,
|
||||
stride,
|
||||
stride_blocks);
|
||||
}
|
||||
});
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do scan op {} on inputs of {} with result of {}.",
|
||||
op_to_string<Op>(),
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
134
mlx/backend/metal/kernels/cexpf.h
Normal file
134
mlx/backend/metal/kernels/cexpf.h
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2008-2013 NVIDIA Corporation
|
||||
// Copyright © 2013 Filipe RNC Maia
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
||||
|
||||
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
||||
// can not be used in JIT.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
using ieee_float_shape_type = union {
|
||||
float value;
|
||||
uint32_t word;
|
||||
};
|
||||
|
||||
inline void get_float_word(thread uint32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
inline void get_float_word(thread int32_t& i, float d) {
|
||||
ieee_float_shape_type gf_u;
|
||||
gf_u.value = (d);
|
||||
(i) = gf_u.word;
|
||||
}
|
||||
|
||||
inline void set_float_word(thread float& d, uint32_t i) {
|
||||
ieee_float_shape_type sf_u;
|
||||
sf_u.word = (i);
|
||||
(d) = sf_u.value;
|
||||
}
|
||||
|
||||
inline float frexp_expf(float x, thread int* expt) {
|
||||
const uint32_t k = 235;
|
||||
const float kln2 = 162.88958740F;
|
||||
|
||||
float exp_x;
|
||||
uint32_t hx;
|
||||
|
||||
exp_x = metal::exp(x - kln2);
|
||||
get_float_word(hx, exp_x);
|
||||
*expt = (hx >> 23) - (0x7f + 127) + k;
|
||||
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
|
||||
return exp_x;
|
||||
}
|
||||
|
||||
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
|
||||
float x, y, exp_x, scale1, scale2;
|
||||
int ex_expt, half_expt;
|
||||
|
||||
x = z.real;
|
||||
y = z.imag;
|
||||
exp_x = frexp_expf(x, &ex_expt);
|
||||
expt += ex_expt;
|
||||
|
||||
half_expt = expt / 2;
|
||||
set_float_word(scale1, (0x7f + half_expt) << 23);
|
||||
half_expt = expt - half_expt;
|
||||
set_float_word(scale2, (0x7f + half_expt) << 23);
|
||||
|
||||
return complex64_t{
|
||||
metal::cos(y) * exp_x * scale1 * scale2,
|
||||
metal::sin(y) * exp_x * scale1 * scale2};
|
||||
}
|
||||
|
||||
inline complex64_t cexpf(const thread complex64_t& z) {
|
||||
float x, y, exp_x;
|
||||
uint32_t hx, hy;
|
||||
|
||||
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
|
||||
|
||||
x = z.real;
|
||||
y = z.imag;
|
||||
|
||||
get_float_word(hy, y);
|
||||
hy &= 0x7fffffff;
|
||||
|
||||
/* cexp(x + I 0) = exp(x) + I 0 */
|
||||
if (hy == 0) {
|
||||
return complex64_t{metal::exp(x), y};
|
||||
}
|
||||
get_float_word(hx, x);
|
||||
/* cexp(0 + I y) = cos(y) + I sin(y) */
|
||||
if ((hx & 0x7fffffff) == 0) {
|
||||
return complex64_t{metal::cos(y), metal::sin(y)};
|
||||
}
|
||||
if (hy >= 0x7f800000) {
|
||||
if ((hx & 0x7fffffff) != 0x7f800000) {
|
||||
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
|
||||
return complex64_t{y - y, y - y};
|
||||
} else if (hx & 0x80000000) {
|
||||
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
|
||||
return complex64_t{0.0, 0.0};
|
||||
} else {
|
||||
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
|
||||
return complex64_t{x, y - y};
|
||||
}
|
||||
}
|
||||
|
||||
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
|
||||
/*
|
||||
* x is between 88.7 and 192, so we must scale to avoid
|
||||
* overflow in expf(x).
|
||||
*/
|
||||
return ldexp_cexpf(z, 0);
|
||||
} else {
|
||||
/*
|
||||
* Cases covered here:
|
||||
* - x < exp_ovfl and exp(x) won't overflow (common case)
|
||||
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
|
||||
* - x = +-Inf (generated by exp())
|
||||
* - x = NaN (spurious inexact exception from y)
|
||||
*/
|
||||
exp_x = metal::exp(x);
|
||||
return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/cexpf.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/expm1f.h"
|
||||
|
||||
@ -178,8 +179,7 @@ struct Exp {
|
||||
return metal::precise::exp(x);
|
||||
};
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
return cexpf(x);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -13,11 +13,6 @@ cuda_skip = {
|
||||
"TestBlas.test_gather_mm_sorted",
|
||||
# Segmented matmul NYI
|
||||
"TestBlas.test_segmented_mm",
|
||||
# Scan NYI
|
||||
"TestArray.test_api",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_logcumsumexp",
|
||||
# Hadamard NYI
|
||||
"TestOps.test_hadamard",
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
|
@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
||||
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
||||
CHECK(allclose(exp(x), expected).item<bool>());
|
||||
|
||||
// Complex of -inf
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
x = array(complex64_t{-inf, -inf});
|
||||
CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
|
||||
}
|
||||
|
||||
// Test expm1
|
||||
@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
|
||||
x = array(-inf);
|
||||
y = array(inf);
|
||||
CHECK_EQ(logaddexp(x, y).item<float>(), inf);
|
||||
|
||||
x = array(complex64_t{1, 1});
|
||||
y = array(complex64_t{-inf, -inf});
|
||||
CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test broadcast") {
|
||||
|
Loading…
Reference in New Issue
Block a user