Compare commits

...

7 Commits

Author SHA1 Message Date
Cheng
ebdd22a8d4 CUDA backend: layernorm 2025-06-11 15:04:08 -07:00
Cheng
c371baf53a
CUDA backend: softmax (#2272) 2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c
CUDA backend: argreduce (#2270) 2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a
CUDA backend: reduce (#2269) 2025-06-11 11:22:25 -07:00
Awni Hannun
c35f4d089a
start cuda circle config (#2256)
* rebase

* fix metal kernel linking issue on cuda

* start cuda circle config
2025-06-10 21:19:47 -07:00
Angelos Katharopoulos
8590c0941e
Add load_safe to the general conv loaders (#2258) 2025-06-10 20:58:16 -07:00
Cheng
095163b8d1
Fix building cpp benchmarks on Linux (#2268) 2025-06-10 17:10:24 -07:00
37 changed files with 2505 additions and 54 deletions

View File

@ -212,6 +212,29 @@ jobs:
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Install Python package
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
python -m venv env
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
build_release:
parameters:
python_version:
@ -348,6 +371,7 @@ workflows:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- cuda_build_and_test
- build_documentation
build_pypi_release:
@ -455,6 +479,8 @@ workflows:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
nightly_build:
when:
and:

View File

@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>

View File

@ -0,0 +1,107 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@ -55,6 +55,9 @@ endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)

View File

@ -6,21 +6,30 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp

View File

@ -0,0 +1,189 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
struct IndexValPair {
uint32_t index;
T val;
};
template <typename T>
struct ArgMin {
constexpr __device__ T init() {
return Limits<T>::max();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T>
struct ArgMax {
constexpr __device__ T init() {
return Limits<T>::min();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
const T* in,
uint32_t* out,
size_t size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides in_strides,
const __grid_constant__ Strides out_strides,
int32_t ndim,
int64_t axis_stride,
int32_t axis_size) {
auto block = cg::this_thread_block();
int64_t index = cg::this_grid().block_rank();
if (index >= size) {
return;
}
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().z;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
best = BlockReduceT(temp).Reduce(best, op);
if (block.thread_rank() == 0) {
out[out_idx] = best.index;
}
}
} // namespace cu
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
Strides in_strides = remove_index(in.strides(), axis_);
Strides out_strides = out.ndim() == in.ndim()
? remove_index(out.strides(), axis_)
: out.strides();
int64_t axis_stride = in.strides()[axis_];
int32_t axis_size = in.shape()[axis_];
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
using InType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{1, 1, BLOCK_DIM};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,
BLOCK_DIM,
N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = &cu::arg_reduce_general<
InType,
cu::ArgMin<InType>,
BLOCK_DIM,
N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
});
});
}
} // namespace mlx::core

11
mlx/backend/cuda/cuda.cpp Normal file
View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
namespace mlx::core::cu {
bool is_available() {
return true;
}
} // namespace mlx::core::cu

10
mlx/backend/cuda/cuda.h Normal file
View File

@ -0,0 +1,10 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cu {
/* Check if the CUDA backend is available. */
bool is_available();
} // namespace mlx::core::cu

View File

@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu

View File

@ -47,6 +47,31 @@ namespace mlx::core {
__VA_ARGS__; \
}
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
{ \
uint32_t _num_threads = NUM_THREADS; \
if (_num_threads <= WARP_SIZE) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 2) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 4) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 8) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 16) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
__VA_ARGS__; \
} else { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
__VA_ARGS__; \
} \
}
// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {

View File

@ -9,6 +9,8 @@
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>
@ -19,6 +21,10 @@ namespace mlx::core::cu {
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32
// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8
@ -26,6 +32,94 @@ namespace mlx::core::cu {
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename = void>
struct Limits {
static constexpr __host__ __device__ T max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T min() {
return cuda::std::numeric_limits<T>::min();
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
return cuda::std::numeric_limits<T>::min();
}
};
template <typename T>
struct Limits<
T,
cuda::std::enable_if_t<
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
static constexpr __host__ __device__ T max() {
return cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T min() {
return -cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
return cuda::std::numeric_limits<T>::lowest();
}
};
// CUDA 11 does not have host side arithmatic operators for half types.
template <typename T>
struct Limits<
T,
cuda::std::enable_if_t<
cuda::std::is_same_v<T, __half> ||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
static constexpr __host__ __device__ T max() {
return cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return -cuda::std::numeric_limits<T>::infinity();
#else
return -cuda::std::numeric_limits<float>::infinity();
#endif
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return cuda::std::numeric_limits<T>::lowest();
#else
return cuda::std::numeric_limits<float>::lowest();
#endif
}
};
template <>
struct Limits<bool> {
static constexpr __host__ __device__ bool max() {
return true;
}
static constexpr __host__ __device__ bool min() {
return false;
}
};
template <>
struct Limits<cuComplex> {
static constexpr __host__ __device__ cuComplex max() {
return {Limits<float>::max(), Limits<float>::max()};
}
static constexpr __host__ __device__ cuComplex min() {
return {Limits<float>::min(), Limits<float>::min()};
}
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
return cuda::std::make_tuple(a_loc, b_loc);
}
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
template <int DIM, bool General = true, typename OffsetT = size_t>
struct LoopedElemToLoc {
int dim;
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
OffsetT offset{0};
int index{0};
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
__device__ void next(const int* shape, const int64_t* strides) {
if (dim == 0) {
return;
}
index++;
offset += OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
}
}
__device__ void next(int n, const int* shape, const int64_t* strides) {
if (dim == 0) {
return;
}
index += n;
offset += n * OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
int extra = index - shape[dim - 1];
if (extra >= shape[dim - 1]) {
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
extra = extra % shape[dim - 1];
} else {
inner_looper.next(shape, strides);
}
index = 0;
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
}
}
}
__device__ OffsetT location() {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, true, OffsetT> {
int dim;
OffsetT offset{0};
int index{0};
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
__device__ void next(const int* shape, const int64_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}
__device__ void next(int n, const int* shape, const int64_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
}
__device__ OffsetT location() {
return offset;
}
};
template <typename OffsetT>
struct LoopedElemToLoc<1, false, OffsetT> {
OffsetT offset{0};
__device__ LoopedElemToLoc(int) {}
__device__ void next(const int*, const int64_t* strides) {
offset += OffsetT(strides[0]);
}
__device__ void next(int n, const int*, const int64_t* strides) {
offset += n * OffsetT(strides[0]);
}
__device__ OffsetT location() {
return offset;
}
};
} // namespace mlx::core::cu

View File

@ -0,0 +1,390 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
return {a.x + b.x, a.y + b.y, a.z + b.z};
}
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
return Reduce(input, cg::plus<T>{}, T{});
}
};
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm(
const T* x,
const T* w,
const T* b,
T* out,
float eps,
int32_t axis_size,
int64_t w_stride,
int64_t b_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
x += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
// Sum.
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceT{block, temp}.Sum(sum);
// Mean.
float mean = sum / axis_size;
// Normalizer.
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
}
cub::StoreDirectBlocked(index, out, xn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm_vjp(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF3::TempStorage f3;
} temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
gx += grid.block_rank() * axis_size;
gw += grid.block_rank() * axis_size;
// Sum.
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceF{block, temp.f}.Sum(sum);
// Mean.
float mean = sum / axis_size;
// Normalizer.
float3 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
}
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
float meanwg = factors.x / axis_size;
float meanwgxc = factors.y / axis_size;
float normalizer2 = 1 / (factors.z / axis_size + eps);
float normalizer = sqrt(normalizer2);
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i];
float gi = gn[i];
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
if constexpr (HAS_W) {
wn[i] = gi * xi;
}
}
cub::StoreDirectBlocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
cub::StoreDirectBlocked(index, gw, wn, axis_size);
}
}
}
} // namespace cu
namespace fast {
bool LayerNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
// TODO: There are duplicate code with backend/metal/normalization.cpp
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream();
auto& out = outputs[0];
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array o = set_output(inputs[0]);
const array& x = o.data_shared_ptr() ? o : out;
const array& w = inputs[1];
const array& b = inputs[2];
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride,
b_stride);
});
});
});
}
void LayerNormVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
const array& b = inputs[2];
auto [g, g_copied] = check_input(inputs[3]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
array& gb = outputs[2];
// Check whether we had a weight.
bool has_w = w.ndim() != 0;
// Allocate space for the outputs.
bool g_in_gx = false;
if (donate_x) {
gx.copy_shared_buffer(x);
} else if (donate_g) {
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
}
int32_t axis_size = x.shape().back();
int32_t n_rows = x.data_size() / axis_size;
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
// Allocate a temporary to store the gradients for w and allocate the output
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
if (has_w) {
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
encoder.add_temporary(gw_temp);
}
}
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// Finish with the gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
}
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(g);
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
});
if (has_w) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
}
}
} // namespace fast
} // namespace mlx::core

View File

@ -0,0 +1,159 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
inline __device__ T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return __expf(x);
}
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void logsumexp(const T* in, T* out, int axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
cg::greater<AccT> max_op;
cg::plus<AccT> plus_op;
// Thread reduce.
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
// First warp reduce.
prevmax = maxval;
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
normalizer = cg::reduce(warp, normalizer, plus_op);
__shared__ AccT local_max[WARP_SIZE];
__shared__ AccT local_normalizer[WARP_SIZE];
// Write to shared memory and do second warp reduce.
prevmax = maxval;
if (warp.thread_rank() == 0) {
local_max[warp.meta_group_rank()] = maxval;
}
block.sync();
maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()]
: Limits<AccT>::finite_min();
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) {
local_normalizer[warp.meta_group_rank()] = normalizer;
}
block.sync();
normalizer = warp.thread_rank() < warp.meta_group_size()
? local_normalizer[warp.thread_rank()]
: AccT{};
normalizer = cg::reduce(warp, normalizer, plus_op);
// Write output.
if (block.thread_rank() == 0) {
out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
}
}
} // namespace cu
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("LogSumExp::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous.
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
namespace mlx::core::cu {
bool is_available() {
return false;
}
} // namespace mlx::core::cu

View File

@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
}
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(BlockMaskedMM)
NO_GPU_MULTI(Compiled)
NO_GPU(Convolution)
@ -86,18 +85,15 @@ NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU(LogSumExp)
NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Reduce)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
@ -105,8 +101,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_USE_FALLBACK(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)

View File

@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <cassert>
namespace mlx::core {
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Reduce::eval_gpu");
assert(inputs.size() == 1);
array in = inputs[0];
// Make sure no identity reductions trickle down here.
assert(!axes_.empty());
assert(out.size() != in.size());
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
// Fill out with init value.
if (in.size() == 0) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
thrust::fill_n(
cu::thrust_policy(stream),
thrust::device_pointer_cast(out.data<OutType>()),
out.data_size(),
cu::ReduceInit<OP, InType>::value());
});
});
});
return;
}
// Reduce.
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
// recompute the plan.
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
if ((plan.type == ContiguousAllReduce) ||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
if (plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}
throw std::runtime_error("No plan reached in reduce.");
}
} // namespace mlx::core

View File

@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct ColReduceArgs {
// The size of the contiguous column reduction.
size_t reduction_size;
int64_t reduction_stride;
// Input shape and strides excluding the reduction axes.
Shape shape;
Strides strides;
int ndim;
// Input shape and strides of the reduction axes (including last dimension).
Shape reduce_shape;
Strides reduce_strides;
int reduce_ndim;
// The number of column we are reducing. Namely prod(reduce_shape).
size_t non_col_reductions;
ColReduceArgs(
const array& in,
const ReductionPlan& plan,
const std::vector<int>& axes) {
assert(!plan.shape.empty());
reduction_size = plan.shape.back();
reduction_stride = plan.strides.back();
int64_t stride_back = 1;
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
while (!shape_vec.empty() && stride_back < reduction_stride) {
stride_back *= shape_vec.back();
shape_vec.pop_back();
strides_vec.pop_back();
}
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
reduce_shape = const_param(plan.shape);
reduce_strides = const_param(plan.strides);
reduce_ndim = plan.shape.size();
non_col_reductions = 1;
for (int i = 0; i < reduce_ndim - 1; i++) {
non_col_reductions *= reduce_shape[i];
}
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
int column =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
if (column * N_READS >= args.reduction_stride) {
return;
}
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(
block.thread_index().y,
args.reduce_shape.data(),
args.reduce_strides.data());
for (size_t r = block.thread_index().y;
r < args.non_col_reductions * args.reduction_size;
r += block.dim_threads().y) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location()),
vals,
args.reduction_stride,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(
block.dim_threads().y,
args.reduce_shape.data(),
args.reduce_strides.data());
}
// Do block reduce when each column has more than 1 element to reduce.
if (block.dim_threads().y > 1) {
__shared__ U shared_vals[32 * 8 * N_READS];
size_t col =
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
shared_vals[col * N_READS + i] = totals[i];
}
block.sync();
if (block.thread_index().y == 0) {
for (int i = 0; i < N_READS; i++) {
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
}
for (int j = 1; j < block.dim_threads().y; j++) {
col = j * block.dim_threads().x + block.thread_index().x;
for (int i = 0; i < N_READS; i++) {
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
}
}
}
}
// Write result.
if (block.thread_index().y == 0) {
cub::StoreDirectBlocked(
column,
out + out_idx * args.reduction_stride,
totals,
args.reduction_stride);
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BM,
int BN,
int N_READS = 4>
__global__ void col_reduce_looped(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
constexpr int n_warps = BN / N_READS;
int out_idx = grid.block_rank() / grid.dim_blocks().x;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
Op op;
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = ReduceInit<Op, T>::value();
}
// Read input to local.
int r = block.thread_rank() / n_warps;
int column = block.thread_rank() % n_warps;
int in_offset = grid.block_index().x * BN;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
U vals[N_READS];
cub::LoadDirectBlocked(
column,
make_cast_iterator<U>(in + loop.location() + in_offset),
vals,
args.reduction_stride - in_offset,
ReduceInit<Op, T>::value());
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
// Do warp reduce for each output.
constexpr int n_outputs = BN / n_warps;
static_assert(BM == 32 && n_outputs == N_READS);
__shared__ U shared_vals[BM * BN];
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
for (int i = 0; i < N_READS; i++) {
shared_vals[col + i] = totals[i];
}
block.sync();
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
for (int i = 0; i < n_outputs; i++) {
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
}
// Write result.
if (warp.thread_rank() == 0) {
size_t out_offset = grid.block_index().x * BN;
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + out_idx * args.reduction_stride + out_offset,
totals,
args.reduction_stride - out_offset);
}
}
} // namespace cu
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args) {
auto out_shape = out.shape();
auto out_strides = out.strides();
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
out_shape.pop_back();
out_strides.pop_back();
}
return get_2d_grid_dims(out_shape, out_strides);
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
cu::ColReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr int N_READS = 4;
dim3 block_dims;
dim3 num_blocks = output_grid_for_col_reduce(out, args);
num_blocks.z = num_blocks.y;
num_blocks.y = num_blocks.x;
auto kernel =
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
size_t total = args.non_col_reductions * args.reduction_size;
if (total < 32) {
size_t stride_blocks =
cuda::ceil_div(args.reduction_stride, N_READS);
block_dims.x = std::min(stride_blocks, 32ul);
block_dims.y = std::min(total, 8ul);
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
} else {
constexpr int BM = 32;
constexpr int BN = 32;
block_dims.x = BM * BN / N_READS;
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
kernel = cu::
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), args);
});
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,74 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
// Dispatch dynamic ndim to constexpr.
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
if (ndim == 1) { \
constexpr uint32_t NDIM = 1; \
__VA_ARGS__; \
} else if (ndim == 2) { \
constexpr uint32_t NDIM = 2; \
__VA_ARGS__; \
} else { \
constexpr uint32_t NDIM = 5; \
__VA_ARGS__; \
}
// Dispatch reduce ops to constexpr.
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
if (REDUCE == Reduce::ReduceType::And) { \
using OP = cu::And; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Or) { \
using OP = cu::Or; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Sum) { \
using OP = cu::Sum; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Prod) { \
using OP = cu::Prod; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Max) { \
using OP = cu::Max; \
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Min) { \
using OP = cu::Min; \
__VA_ARGS__; \
} else { \
throw std::invalid_argument("Unknown reduce type."); \
}
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
void row_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
} // namespace mlx::core

View File

@ -0,0 +1,144 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/kernels/utils.cuh"
namespace mlx::core::cu {
// Reduce ops.
struct And {
__device__ bool operator()(bool a, bool b) {
return a && b;
}
};
struct Or {
__device__ bool operator()(bool a, bool b) {
return a || b;
}
};
struct Sum {
template <typename T>
__device__ T operator()(T a, T b) {
return a + b;
}
};
struct Prod {
template <typename T>
__device__ T operator()(T a, T b) {
return a * b;
}
};
struct Min {
template <typename T>
__device__ T operator()(T a, T b) {
return a < b ? a : b;
}
};
struct Max {
template <typename T>
__device__ T operator()(T a, T b) {
return a > b ? a : b;
}
};
// Traits to get the result type of reduce op.
template <typename Op, typename T>
struct ReduceResult;
template <typename T>
struct ReduceResult<And, T> {
using type = bool;
};
template <typename T>
struct ReduceResult<Or, T> {
using type = bool;
};
template <typename T>
struct ReduceResult<Sum, T> {
using type = cuda::std::conditional_t<
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
int32_t,
T>;
};
template <typename T>
struct ReduceResult<Prod, T> {
using type = cuda::std::conditional_t<
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
int32_t,
T>;
};
template <typename T>
struct ReduceResult<Min, T> {
using type = T;
};
template <typename T>
struct ReduceResult<Max, T> {
using type = T;
};
// Traits to get the init value of reduce op.
template <typename Op, typename T>
struct ReduceInit;
template <typename T>
struct ReduceInit<And, T> {
static constexpr __host__ __device__ bool value() {
return true;
}
};
template <typename T>
struct ReduceInit<Or, T> {
static constexpr __host__ __device__ bool value() {
return false;
}
};
template <typename T>
struct ReduceInit<Sum, T> {
static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{0, 0};
} else {
return typename ReduceResult<Sum, T>::type{0};
}
}
};
template <typename T>
struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 1};
} else {
return typename ReduceResult<Prod, T>::type{1};
}
}
};
template <typename T>
struct ReduceInit<Min, T> {
static constexpr __host__ __device__ T value() {
return Limits<T>::max();
}
};
template <typename T>
struct ReduceInit<Max, T> {
static constexpr __host__ __device__ T value() {
return Limits<T>::min();
}
};
} // namespace mlx::core::cu

View File

@ -0,0 +1,250 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct RowReduceArgs {
// The size of the row being reduced, i.e. the size of last dimension.
int row_size;
// Input shape and strides excluding the reduction axes.
Shape shape;
Strides strides;
int ndim;
// Input shape and strides of the reduction axes excluding last dimension.
Shape reduce_shape;
Strides reduce_strides;
int reduce_ndim;
// The number of rows we are reducing. Namely prod(reduce_shape).
size_t non_row_reductions;
RowReduceArgs(
const array& in,
const ReductionPlan& plan,
const std::vector<int>& axes) {
assert(!plan.shape.empty());
row_size = plan.shape.back();
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
reduce_shape = const_param(plan.shape);
reduce_strides = const_param(plan.strides);
reduce_ndim = plan.shape.size() - 1;
non_row_reductions = 1;
for (int i = 0; i < reduce_ndim; i++) {
non_row_reductions *= reduce_shape[i];
}
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
size_t out_idx = cg::this_grid().thread_rank();
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
out[out_idx] = total_val;
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / WARP_SIZE;
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
n += WARP_SIZE) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
}
total_val = cg::reduce(warp, total_val, op);
if (warp.thread_rank() == 0) {
out[out_idx] = total_val;
}
}
template <
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM_X,
int N_READS = 4>
__global__ void row_reduce_looped(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM_X + block.thread_index().x,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
total_val = BlockReduceT(temp).Reduce(total_val, op);
if (block.thread_rank() == 0) {
out[out_idx] = total_val;
}
}
} // namespace cu
void row_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
cu::RowReduceArgs args(in, plan, axes);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
using InType = cuda_type_t<CTYPE>;
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using OutType = cu::ReduceResult<OP, InType>::type;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
constexpr size_t N_READS = 4;
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims, num_blocks;
auto kernel =
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
if (args.row_size <= 64) {
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
(args.non_row_reductions <= 8)) {
block_dims.x = std::min(out_dims.x, 1024u);
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
num_blocks.y = out_dims.y;
} else {
block_dims.x = WARP_SIZE;
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
kernel =
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
}
} else {
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
num_blocks.y = out_dims.x;
num_blocks.z = out_dims.y;
block_dims.x = BLOCK_DIM_X;
kernel = cu::row_reduce_looped<
InType,
OutType,
OP,
NDIM,
BLOCK_DIM_X,
N_READS>;
});
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(), out.data<OutType>(), out.size(), args);
});
});
});
});
}
} // namespace mlx::core

View File

@ -0,0 +1,84 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <thrust/device_ptr.h>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_segmented_reduce.cuh>
namespace mlx::core {
template <typename... Args>
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
}
template <typename... Args>
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
// Allocate temporary storage.
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Run op.
CHECK_CUDA_ERROR(
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
}
struct MultiplyOp {
int factor;
__device__ int operator()(int i) {
return i * factor;
}
};
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
auto in_iter = cu::make_cast_iterator<OutType>(
thrust::device_pointer_cast(in.data<InType>()));
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
auto init = cu::ReduceInit<OP, InType>::value();
if (plan.type == ContiguousAllReduce) {
cub_all_reduce(
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
} else if (plan.type == ContiguousReduce) {
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
cub_segmented_reduce(
encoder,
in_iter,
out_ptr,
out.size(),
offsets,
offsets + 1,
OP(),
init,
stream);
} else {
throw std::runtime_error("Unsupported plan in segmented_reduce.");
}
});
});
});
}
} // namespace mlx::core

160
mlx/backend/cuda/softmax.cu Normal file
View File

@ -0,0 +1,160 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
inline __device__ T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return __expf(x);
}
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void softmax(const T* in, T* out, int axis_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
in += grid.block_rank() * axis_size;
out += grid.block_rank() * axis_size;
cg::greater<AccT> max_op;
cg::plus<AccT> plus_op;
// Thread reduce.
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::finite_min());
prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
// First warp reduce.
prevmax = maxval;
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
normalizer = cg::reduce(warp, normalizer, plus_op);
__shared__ AccT local_max[WARP_SIZE];
__shared__ AccT local_normalizer[WARP_SIZE];
// Write to shared memory and do second warp reduce.
prevmax = maxval;
if (warp.thread_rank() == 0) {
local_max[warp.meta_group_rank()] = maxval;
}
block.sync();
maxval = warp.thread_rank() < warp.meta_group_size()
? local_max[warp.thread_rank()]
: Limits<AccT>::finite_min();
maxval = cg::reduce(warp, maxval, max_op);
normalizer = normalizer * softmax_exp(prevmax - maxval);
if (warp.thread_rank() == 0) {
local_normalizer[warp.meta_group_rank()] = normalizer;
}
block.sync();
normalizer = warp.thread_rank() < warp.meta_group_size()
? local_normalizer[warp.thread_rank()]
: AccT{};
normalizer = cg::reduce(warp, normalizer, plus_op);
normalizer = 1 / normalizer;
// Write output.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
T vals[N_READS];
cub::LoadDirectBlocked(index, in, vals, axis_size);
for (int i = 0; i < N_READS; i++) {
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
}
cub::StoreDirectBlocked(index, out, vals, axis_size);
}
}
} // namespace cu
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Softmax::eval_gpu");
assert(inputs.size() == 1);
auto& s = stream();
// Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array in = set_output(inputs[0]);
bool precise = in.dtype() != float32 && precise_;
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
if (precise) {
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
}
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
});
});
}
} // namespace mlx::core

View File

@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
// Get channel iteration info
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = channel_k_iters;
bool align_C = conv_params.C % bk == 0;
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_general_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
std::string hash_name;
hash_name.reserve(64);
concatenate(hash_name, kname, "_alC_", align_C);
metal::MTLFCList func_consts = {
{&align_C, MTL::DataType::DataTypeBool, 200},
};
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
auto kernel = get_steel_conv_general_kernel(
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
bool out_large =
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}

View File

@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
int bm,
int bn,
@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
wn);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_fft_kernel(

View File

@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
int bm,
int bn,

View File

@ -2,6 +2,8 @@
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
constant bool align_C [[function_constant(200)]];
template <
typename T,
int BM,
@ -118,23 +120,58 @@ implicit_gemm_conv_2d_general(
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
if (align_C) {
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
else {
for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
const short remaining_k = params->C % BK;
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
// Load elements into threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(remaining_k);
loader_b.load_safe(remaining_k);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
threadgroup_barrier(mem_flags::mem_none);

View File

@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral {
}
}
METAL_FUNC void load_safe(const short remaining_k) const {
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
int ih_dil = read_ih[i] + h_flip * params->kdil[0];
int iw_dil = read_iw[i] + w_flip * params->kdil[1];
int ih = ih_dil / params->idil[0];
int iw = iw_dil / params->idil[1];
size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
// Read from input if in bounds
if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
(iw_dil >= 0 && iw < params->iS[1])) {
if (bj + vec_size <= remaining_k) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = (src[i])[offset + j];
}
} else {
for (short j = 0; j < vec_size; ++j) {
if (bj + j < remaining_k) {
dst[is * dst_ld + j] = (src[i])[offset + j];
} else {
dst[is * dst_ld + j] = T(0);
}
}
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;
@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral {
}
}
METAL_FUNC void load_safe(const short remaining_k) const {
const device T* curr_src = src + weight_h * params->wt_strides[1] +
weight_w * params->wt_strides[2];
if ((start_row + BN <= params->O)) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BN; i += TROWS) {
if (bj + vec_size <= remaining_k) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
} else {
for (short j = 0; j < vec_size; j++) {
if (bj + j < remaining_k) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
} else {
dst[i * dst_ld + j] = T(0);
}
}
}
}
} else {
for (short i = 0; i < BN; i += TROWS) {
if ((start_row + i) < params->O) {
if (bj + vec_size <= remaining_k) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
} else {
for (short j = 0; j < vec_size; j++) {
if (bj + j < remaining_k) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
} else {
dst[i * dst_ld + j] = T(0);
}
}
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;

View File

@ -3,8 +3,11 @@
#include <stdexcept>
#include "mlx/backend/metal/metal.h"
#include "mlx/fast.h"
namespace mlx::core::metal {
namespace mlx::core {
namespace metal {
bool is_available() {
return false;
@ -19,4 +22,21 @@ device_info() {
"[metal::device_info] Cannot get device info without metal backend");
};
} // namespace mlx::core::metal
} // namespace metal
namespace fast {
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
} // namespace mlx::core

View File

@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
int,
int,
int,
int,
int) {
return d.get_kernel(kernel_name);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_fft_kernel(

View File

@ -2,7 +2,6 @@
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
namespace distributed {

View File

@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/compile.h"
#include "mlx/device.h"

View File

@ -17,10 +17,7 @@
#include "python/src/indexing.h"
#include "python/src/utils.h"
#include "mlx/device.h"
#include "mlx/ops.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
#include "mlx/mlx.h"
namespace mx = mlx::core;
namespace nb = nanobind;
@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
.def(
"__dlpack_device__",
[](const mx::array& a) {
// See
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
if (mx::metal::is_available()) {
// Metal device is available
return nb::make_tuple(8, 0);
} else if (mx::cu::is_available()) {
return nb::make_tuple(13, 0);
} else {
// CPU device
return nb::make_tuple(1, 0);

View File

@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
&mx::set_default_device,
"device"_a,
R"pbdoc(Set the default device.)pbdoc");
m.def(
"is_available",
&mx::is_available,
"device"_a,
R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
}

View File

@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
def test_dlx_device_type(self):
a = mx.array([1, 2, 3])
device_type, device_id = a.__dlpack_device__()
self.assertIn(device_type, [1, 8])
self.assertIn(device_type, [1, 8, 13])
self.assertEqual(device_id, 0)
if device_type == 8:

View File

@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
def test_conv2d_unaligned_channels(self):
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(32, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
x = mx.random.uniform(shape=(2, 16, 16, 21))
w = mx.random.uniform(shape=(21, 3, 3, 21))
y = mx.conv2d(x, w, stream=mx.cpu)
y_hat = mx.conv2d(x, w)
self.assertTrue(mx.allclose(y, y_hat))
if __name__ == "__main__":
unittest.main()

View File

@ -10,7 +10,7 @@ import mlx_tests
class TestDefaultDevice(unittest.TestCase):
def test_mlx_default_device(self):
device = mx.default_device()
if mx.metal.is_available():
if mx.is_available(mx.gpu):
self.assertEqual(device, mx.Device(mx.gpu))
self.assertEqual(str(device), "Device(gpu, 0)")
self.assertEqual(device, mx.gpu)
@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
self.assertEqual(s2.device, mx.default_device())
self.assertNotEqual(s1, s2)
if mx.metal.is_available():
if mx.is_available(mx.gpu):
s_gpu = mx.default_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
s_cpu = mx.new_stream(mx.cpu)
self.assertEqual(s_cpu.device, mx.cpu)
if mx.metal.is_available():
if mx.is_available(mx.gpu):
s_gpu = mx.new_stream(mx.gpu)
self.assertEqual(s_gpu.device, mx.gpu)
else:
@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
if mx.metal.is_available():
if mx.is_available(mx.gpu):
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
self.assertEqual(a.item(), b.item())
s_gpu = mx.new_stream(mx.gpu)

View File

@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
class TestSchedulers(unittest.TestCase):
class TestSchedulers(mlx_tests.MLXTestCase):
def test_decay_lr(self):
for optim_class in optimizers_dict.values():
lr_schedule = opt.step_decay(1e-1, 0.9, 1)