mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
Compare commits
7 Commits
557a21e767
...
ebdd22a8d4
Author | SHA1 | Date | |
---|---|---|---|
![]() |
ebdd22a8d4 | ||
![]() |
c371baf53a | ||
![]() |
ccf78f566c | ||
![]() |
c9fa68664a | ||
![]() |
c35f4d089a | ||
![]() |
8590c0941e | ||
![]() |
095163b8d1 |
@ -212,6 +212,29 @@ jobs:
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install -e ".[dev]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
@ -348,6 +371,7 @@ workflows:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
@ -455,6 +479,8 @@ workflows:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@ -0,0 +1,107 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@ -55,6 +55,9 @@ endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
|
@ -6,21 +6,30 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
|
189
mlx/backend/cuda/arg_reduce.cu
Normal file
189
mlx/backend/cuda/arg_reduce.cu
Normal file
@ -0,0 +1,189 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
struct IndexValPair {
|
||||
uint32_t index;
|
||||
T val;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMin {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val > current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMax {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val < current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T>
|
||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void arg_reduce_general(
|
||||
const T* in,
|
||||
uint32_t* out,
|
||||
size_t size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides in_strides,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t ndim,
|
||||
int64_t axis_stride,
|
||||
int32_t axis_size) {
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int64_t index = cg::this_grid().block_rank();
|
||||
if (index >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||
|
||||
Op op;
|
||||
T init = op.init();
|
||||
IndexValPair<T> best{0, init};
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T vals[N_READS];
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().z;
|
||||
cub::LoadDirectBlocked(
|
||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
best = BlockReduceT(temp).Reduce(best, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = best.index;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgReduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
Shape shape = remove_index(in.shape(), axis_);
|
||||
Strides in_strides = remove_index(in.strides(), axis_);
|
||||
Strides out_strides = out.ndim() == in.ndim()
|
||||
? remove_index(out.strides(), axis_)
|
||||
: out.strides();
|
||||
int64_t axis_stride = in.strides()[axis_];
|
||||
int32_t axis_size = in.shape()[axis_];
|
||||
int32_t ndim = shape.size();
|
||||
|
||||
// ArgReduce.
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims{1, 1, BLOCK_DIM};
|
||||
auto kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMax<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = &cu::arg_reduce_general<
|
||||
InType,
|
||||
cu::ArgMin<InType>,
|
||||
BLOCK_DIM,
|
||||
N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/cuda.cpp
Normal file
11
mlx/backend/cuda/cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
10
mlx/backend/cuda/cuda.h
Normal file
10
mlx/backend/cuda/cuda.h
Normal file
@ -0,0 +1,10 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
/* Check if the CUDA backend is available. */
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cu
|
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
60
mlx/backend/cuda/iterators/strided_iterator.cuh
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/iterator/iterator_adaptor.h>
|
||||
#include <thrust/iterator/iterator_facade.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// RandomAccessIterator for strided access to array entries.
|
||||
template <typename Iterator, typename Stride = int64_t>
|
||||
class strided_iterator
|
||||
: public thrust::
|
||||
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
||||
public:
|
||||
using super_t =
|
||||
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
||||
|
||||
using reference = typename super_t::reference;
|
||||
using difference_type = typename super_t::difference_type;
|
||||
|
||||
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
||||
: super_t(it), stride_(stride) {}
|
||||
|
||||
__host__ __device__ Stride stride() const {
|
||||
return stride_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class thrust::iterator_core_access;
|
||||
|
||||
__host__ __device__ bool equal(const strided_iterator& other) const {
|
||||
return this->base() == other.base();
|
||||
}
|
||||
|
||||
__host__ __device__ void advance(difference_type n) {
|
||||
this->base_reference() += n * stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void increment() {
|
||||
this->base_reference() += stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ void decrement() {
|
||||
this->base_reference() -= stride_;
|
||||
}
|
||||
|
||||
__host__ __device__ difference_type
|
||||
distance_to(const strided_iterator& other) const {
|
||||
const difference_type dist = other.base() - this->base();
|
||||
_CCCL_ASSERT(
|
||||
dist % stride() == 0,
|
||||
"Underlying iterator difference must be divisible by the stride");
|
||||
return dist / stride();
|
||||
}
|
||||
|
||||
Stride stride_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -47,6 +47,31 @@ namespace mlx::core {
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
|
||||
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||
{ \
|
||||
uint32_t _num_threads = NUM_THREADS; \
|
||||
if (_num_threads <= WARP_SIZE) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
struct CTypeToCudaType {
|
||||
|
@ -9,6 +9,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
@ -19,6 +21,10 @@ namespace mlx::core::cu {
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
#define WARP_SIZE 32
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
#define MAX_NDIM 8
|
||||
@ -26,6 +32,94 @@ namespace mlx::core::cu {
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct Limits {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, __half> ||
|
||||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#else
|
||||
return -cuda::std::numeric_limits<float>::infinity();
|
||||
#endif
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#else
|
||||
return cuda::std::numeric_limits<float>::lowest();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr __host__ __device__ bool max() {
|
||||
return true;
|
||||
}
|
||||
static constexpr __host__ __device__ bool min() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<cuComplex> {
|
||||
static constexpr __host__ __device__ cuComplex max() {
|
||||
return {Limits<float>::max(), Limits<float>::max()};
|
||||
}
|
||||
static constexpr __host__ __device__ cuComplex min() {
|
||||
return {Limits<float>::min(), Limits<float>::min()};
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int DIM, bool General = true, typename OffsetT = size_t>
|
||||
struct LoopedElemToLoc {
|
||||
int dim;
|
||||
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index++;
|
||||
offset += OffsetT(strides[dim - 1]);
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index += n;
|
||||
offset += n * OffsetT(strides[dim - 1]);
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
if (extra >= shape[dim - 1]) {
|
||||
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||
extra = extra % shape[dim - 1];
|
||||
} else {
|
||||
inner_looper.next(shape, strides);
|
||||
}
|
||||
index = 0;
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, true, OffsetT> {
|
||||
int dim;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
index++;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
index += n;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset = index * OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
OffsetT offset{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int) {}
|
||||
|
||||
__device__ void next(const int*, const int64_t* strides) {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int*, const int64_t* strides) {
|
||||
offset += n * OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
390
mlx/backend/cuda/layer_norm.cu
Normal file
390
mlx/backend/cuda/layer_norm.cu
Normal file
@ -0,0 +1,390 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
inline __device__ float3 plus_f3(const float3& a, const float3& b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||
template <typename T, int BLOCK_DIM>
|
||||
struct BlockBroadcastReduce {
|
||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
||||
|
||||
cg::thread_block& block;
|
||||
TempStorage& temp;
|
||||
|
||||
template <typename Op>
|
||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
T x = cg::reduce(warp, input, op);
|
||||
if (warp.thread_rank() == 0) {
|
||||
temp[warp.meta_group_rank()] = x;
|
||||
}
|
||||
block.sync();
|
||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||
: init_value;
|
||||
return cg::reduce(warp, x, op);
|
||||
}
|
||||
|
||||
__device__ T Sum(const T& input) {
|
||||
return Reduce(input, cg::plus<T>{}, T{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* b,
|
||||
T* out,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride,
|
||||
int64_t b_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
}
|
||||
sum = BlockReduceT{block, temp}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
normalizer += t * t;
|
||||
}
|
||||
}
|
||||
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T bn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, xn, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void layer_norm_vjp(
|
||||
const T* x,
|
||||
const T* w,
|
||||
const T* g,
|
||||
T* gx,
|
||||
T* gw,
|
||||
float eps,
|
||||
int32_t axis_size,
|
||||
int64_t w_stride) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
||||
using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
|
||||
__shared__ union {
|
||||
typename BlockReduceF::TempStorage f;
|
||||
typename BlockReduceF3::TempStorage f3;
|
||||
} temp;
|
||||
|
||||
x += grid.block_rank() * axis_size;
|
||||
g += grid.block_rank() * axis_size;
|
||||
gx += grid.block_rank() * axis_size;
|
||||
gw += grid.block_rank() * axis_size;
|
||||
|
||||
// Sum.
|
||||
float sum = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS] = {};
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
|
||||
}
|
||||
sum = BlockReduceF{block, temp.f}.Sum(sum);
|
||||
|
||||
// Mean.
|
||||
float mean = sum / axis_size;
|
||||
|
||||
// Normalizer.
|
||||
float3 factors = {};
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
T xn[N_READS];
|
||||
T wn[N_READS] = {};
|
||||
T gn[N_READS] = {};
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float t = static_cast<float>(xn[i]) - mean;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
float wg = wi * gi;
|
||||
factors = plus_f3(factors, {wg, wg * t, t * t});
|
||||
}
|
||||
}
|
||||
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
|
||||
float meanwg = factors.x / axis_size;
|
||||
float meanwgxc = factors.y / axis_size;
|
||||
float normalizer2 = 1 / (factors.z / axis_size + eps);
|
||||
float normalizer = sqrt(normalizer2);
|
||||
|
||||
// Outputs.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T xn[N_READS];
|
||||
T wn[N_READS];
|
||||
T gn[N_READS];
|
||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||
float wi = wn[i];
|
||||
float gi = gn[i];
|
||||
xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
|
||||
if constexpr (HAS_W) {
|
||||
wn[i] = gi * xi;
|
||||
}
|
||||
}
|
||||
cub::StoreDirectBlocked(index, gx, xn, axis_size);
|
||||
if constexpr (HAS_W) {
|
||||
cub::StoreDirectBlocked(index, gw, wn, axis_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool LayerNorm::use_fallback(Stream s) {
|
||||
return s.device == Device::cpu;
|
||||
}
|
||||
|
||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("LayerNorm::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
array o = set_output(inputs[0]);
|
||||
const array& x = o.data_shared_ptr() ? o : out;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride,
|
||||
b_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return {x, false};
|
||||
}
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return {x_copy, true};
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
|
||||
// Check whether we had a weight.
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs.
|
||||
bool g_in_gx = false;
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
}
|
||||
|
||||
int32_t axis_size = x.shape().back();
|
||||
int32_t n_rows = x.data_size() / axis_size;
|
||||
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(g);
|
||||
encoder.set_output_array(gx);
|
||||
encoder.set_output_array(gw_temp);
|
||||
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BOOL(has_w, HAS_W, {
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
gx.data<DataType>(),
|
||||
gw_temp.data<DataType>(),
|
||||
eps_,
|
||||
axis_size,
|
||||
w_stride);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
if (has_w) {
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
159
mlx/backend/cuda/logsumexp.cu
Normal file
159
mlx/backend/cuda/logsumexp.cu
Normal file
@ -0,0 +1,159 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void logsumexp(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
|
||||
cg::greater<AccT> max_op;
|
||||
cg::plus<AccT> plus_op;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
// Write output.
|
||||
if (block.thread_rank() == 0) {
|
||||
out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("LogSumExp::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/no_cuda.cpp
Normal file
11
mlx/backend/cuda/no_cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
}
|
||||
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Convolution)
|
||||
@ -86,18 +85,15 @@ NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(SliceUpdate)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
@ -105,8 +101,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_USE_FALLBACK(LayerNorm)
|
||||
NO_GPU_MULTI(LayerNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RMSNorm)
|
||||
NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_USE_FALLBACK(RoPE)
|
||||
|
82
mlx/backend/cuda/reduce.cu
Normal file
82
mlx/backend/cuda/reduce.cu
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/fill.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Reduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
array in = inputs[0];
|
||||
|
||||
// Make sure no identity reductions trickle down here.
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
ReductionPlan plan = get_reduction_plan(in, axes_);
|
||||
|
||||
// If it is a general reduce then copy the input to a contiguous array and
|
||||
// recompute the plan.
|
||||
if (plan.type == GeneralReduce) {
|
||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, in_copy, CopyType::General, s);
|
||||
encoder.add_temporary(in_copy);
|
||||
in = in_copy;
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
|
||||
row_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce ||
|
||||
plan.type == GeneralStridedReduce) {
|
||||
col_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::runtime_error("No plan reached in reduce.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
@ -0,0 +1,278 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct ColReduceArgs {
|
||||
// The size of the contiguous column reduction.
|
||||
size_t reduction_size;
|
||||
int64_t reduction_stride;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes (including last dimension).
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of column we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_col_reductions;
|
||||
|
||||
ColReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
strides_vec.pop_back();
|
||||
}
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size();
|
||||
|
||||
non_col_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim - 1; i++) {
|
||||
non_col_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int column =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
if (column * N_READS >= args.reduction_stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(
|
||||
block.thread_index().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
for (size_t r = block.thread_index().y;
|
||||
r < args.non_col_reductions * args.reduction_size;
|
||||
r += block.dim_threads().y) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.reduction_stride,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(
|
||||
block.dim_threads().y,
|
||||
args.reduce_shape.data(),
|
||||
args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do block reduce when each column has more than 1 element to reduce.
|
||||
if (block.dim_threads().y > 1) {
|
||||
__shared__ U shared_vals[32 * 8 * N_READS];
|
||||
size_t col =
|
||||
block.thread_index().y * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col * N_READS + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
if (block.thread_index().y == 0) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = shared_vals[block.thread_index().x * N_READS + i];
|
||||
}
|
||||
for (int j = 1; j < block.dim_threads().y; j++) {
|
||||
col = j * block.dim_threads().x + block.thread_index().x;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (block.thread_index().y == 0) {
|
||||
cub::StoreDirectBlocked(
|
||||
column,
|
||||
out + out_idx * args.reduction_stride,
|
||||
totals,
|
||||
args.reduction_stride);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BM,
|
||||
int BN,
|
||||
int N_READS = 4>
|
||||
__global__ void col_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
constexpr int n_warps = BN / N_READS;
|
||||
|
||||
int out_idx = grid.block_rank() / grid.dim_blocks().x;
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
Op op;
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
// Read input to local.
|
||||
int r = block.thread_rank() / n_warps;
|
||||
int column = block.thread_rank() % n_warps;
|
||||
int in_offset = grid.block_index().x * BN;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
column,
|
||||
make_cast_iterator<U>(in + loop.location() + in_offset),
|
||||
vals,
|
||||
args.reduction_stride - in_offset,
|
||||
ReduceInit<Op, T>::value());
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
// Do warp reduce for each output.
|
||||
constexpr int n_outputs = BN / n_warps;
|
||||
static_assert(BM == 32 && n_outputs == N_READS);
|
||||
__shared__ U shared_vals[BM * BN];
|
||||
size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[col + i] = totals[i];
|
||||
}
|
||||
block.sync();
|
||||
col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] = cg::reduce(warp, shared_vals[col + i], op);
|
||||
}
|
||||
|
||||
// Write result.
|
||||
if (warp.thread_rank() == 0) {
|
||||
size_t out_offset = grid.block_index().x * BN;
|
||||
cub::StoreDirectBlocked(
|
||||
warp.meta_group_rank(),
|
||||
out + out_idx * args.reduction_stride + out_offset,
|
||||
totals,
|
||||
args.reduction_stride - out_offset);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
const array& out,
|
||||
const cu::ColReduceArgs& args) {
|
||||
auto out_shape = out.shape();
|
||||
auto out_strides = out.strides();
|
||||
while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
|
||||
out_shape.pop_back();
|
||||
out_strides.pop_back();
|
||||
}
|
||||
return get_2d_grid_dims(out_shape, out_strides);
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr int N_READS = 4;
|
||||
dim3 block_dims;
|
||||
dim3 num_blocks = output_grid_for_col_reduce(out, args);
|
||||
num_blocks.z = num_blocks.y;
|
||||
num_blocks.y = num_blocks.x;
|
||||
auto kernel =
|
||||
cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
size_t total = args.non_col_reductions * args.reduction_size;
|
||||
if (total < 32) {
|
||||
size_t stride_blocks =
|
||||
cuda::ceil_div(args.reduction_stride, N_READS);
|
||||
block_dims.x = std::min(stride_blocks, 32ul);
|
||||
block_dims.y = std::min(total, 8ul);
|
||||
num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
|
||||
} else {
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
block_dims.x = BM * BN / N_READS;
|
||||
num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
|
||||
kernel = cu::
|
||||
col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
74
mlx/backend/cuda/reduce/reduce.cuh
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Dispatch dynamic ndim to constexpr.
|
||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||
if (ndim == 1) { \
|
||||
constexpr uint32_t NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
} else if (ndim == 2) { \
|
||||
constexpr uint32_t NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t NDIM = 5; \
|
||||
__VA_ARGS__; \
|
||||
}
|
||||
|
||||
// Dispatch reduce ops to constexpr.
|
||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||
if (REDUCE == Reduce::ReduceType::And) { \
|
||||
using OP = cu::And; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||
using OP = cu::Or; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||
using OP = cu::Sum; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||
using OP = cu::Prod; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||
using OP = cu::Max; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||
using OP = cu::Min; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
}
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan);
|
||||
|
||||
} // namespace mlx::core
|
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
144
mlx/backend/cuda/reduce/reduce_ops.cuh
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceResult;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<And, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Or, T> {
|
||||
using type = bool;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Sum, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Prod, T> {
|
||||
using type = cuda::std::conditional_t<
|
||||
(cuda::std::is_integral_v<T> && sizeof(T) <= 4),
|
||||
int32_t,
|
||||
T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Min, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceResult<Max, T> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
// Traits to get the init value of reduce op.
|
||||
template <typename Op, typename T>
|
||||
struct ReduceInit;
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<And, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Or, T> {
|
||||
static constexpr __host__ __device__ bool value() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Sum, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{0, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Sum, T>::type{0};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Min, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ReduceInit<Max, T> {
|
||||
static constexpr __host__ __device__ T value() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
250
mlx/backend/cuda/reduce/row_reduce.cu
Normal file
@ -0,0 +1,250 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct RowReduceArgs {
|
||||
// The size of the row being reduced, i.e. the size of last dimension.
|
||||
int row_size;
|
||||
|
||||
// Input shape and strides excluding the reduction axes.
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
int ndim;
|
||||
|
||||
// Input shape and strides of the reduction axes excluding last dimension.
|
||||
Shape reduce_shape;
|
||||
Strides reduce_strides;
|
||||
int reduce_ndim;
|
||||
|
||||
// The number of rows we are reducing. Namely prod(reduce_shape).
|
||||
size_t non_row_reductions;
|
||||
|
||||
RowReduceArgs(
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
assert(!plan.shape.empty());
|
||||
row_size = plan.shape.back();
|
||||
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
std::tie(shape_vec, strides_vec) =
|
||||
collapse_contiguous_dims(shape_vec, strides_vec);
|
||||
shape = const_param(shape_vec);
|
||||
strides = const_param(strides_vec);
|
||||
ndim = shape_vec.size();
|
||||
|
||||
reduce_shape = const_param(plan.shape);
|
||||
reduce_strides = const_param(plan.strides);
|
||||
reduce_ndim = plan.shape.size() - 1;
|
||||
|
||||
non_row_reductions = 1;
|
||||
for (int i = 0; i < reduce_ndim; i++) {
|
||||
non_row_reductions *= reduce_shape[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
size_t out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_small_warp(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
||||
n += WARP_SIZE) {
|
||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
total_val = cg::reduce(warp, total_val, op);
|
||||
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM_X,
|
||||
int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
Op op;
|
||||
|
||||
U total_val = ReduceInit<Op, T>::value();
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
||||
r++) {
|
||||
U vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM_X + block.thread_index().x,
|
||||
make_cast_iterator<U>(in + loop.location()),
|
||||
vals,
|
||||
args.row_size,
|
||||
ReduceInit<Op, T>::value());
|
||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
||||
}
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void row_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
cu::RowReduceArgs args(in, plan, axes);
|
||||
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
constexpr size_t N_READS = 4;
|
||||
dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block_dims, num_blocks;
|
||||
auto kernel =
|
||||
cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
|
||||
if (args.row_size <= 64) {
|
||||
if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
|
||||
(args.non_row_reductions <= 8)) {
|
||||
block_dims.x = std::min(out_dims.x, 1024u);
|
||||
num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
|
||||
num_blocks.y = out_dims.y;
|
||||
} else {
|
||||
block_dims.x = WARP_SIZE;
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
kernel =
|
||||
cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
|
||||
}
|
||||
} else {
|
||||
size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
|
||||
num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
|
||||
MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
|
||||
num_blocks.y = out_dims.x;
|
||||
num_blocks.z = out_dims.y;
|
||||
block_dims.x = BLOCK_DIM_X;
|
||||
kernel = cu::row_reduce_looped<
|
||||
InType,
|
||||
OutType,
|
||||
OP,
|
||||
NDIM,
|
||||
BLOCK_DIM_X,
|
||||
N_READS>;
|
||||
});
|
||||
}
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
in.data<InType>(), out.data<OutType>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
84
mlx/backend/cuda/reduce/segmented_reduce.cu
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <cub/device/device_reduce.cuh>
|
||||
#include <cub/device/device_segmented_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename... Args>
|
||||
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
// Allocate temporary storage.
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
encoder.add_temporary(temp);
|
||||
// Run op.
|
||||
CHECK_CUDA_ERROR(
|
||||
cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
|
||||
}
|
||||
|
||||
struct MultiplyOp {
|
||||
int factor;
|
||||
__device__ int operator()(int i) {
|
||||
return i * factor;
|
||||
}
|
||||
};
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
auto in_iter = cu::make_cast_iterator<OutType>(
|
||||
thrust::device_pointer_cast(in.data<InType>()));
|
||||
auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
|
||||
auto init = cu::ReduceInit<OP, InType>::value();
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
cub_all_reduce(
|
||||
encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
|
||||
} else if (plan.type == ContiguousReduce) {
|
||||
auto offsets = thrust::make_transform_iterator(
|
||||
thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
|
||||
cub_segmented_reduce(
|
||||
encoder,
|
||||
in_iter,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
offsets,
|
||||
offsets + 1,
|
||||
OP(),
|
||||
init,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported plan in segmented_reduce.");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
160
mlx/backend/cuda/softmax.cu
Normal file
160
mlx/backend/cuda/softmax.cu
Normal file
@ -0,0 +1,160 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return __expf(x);
|
||||
}
|
||||
|
||||
template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void softmax(const T* in, T* out, int axis_size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
in += grid.block_rank() * axis_size;
|
||||
out += grid.block_rank() * axis_size;
|
||||
|
||||
cg::greater<AccT> max_op;
|
||||
cg::plus<AccT> plus_op;
|
||||
|
||||
// Thread reduce.
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min();
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
AccT vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
r * BLOCK_DIM + block.thread_rank(),
|
||||
make_cast_iterator<AccT>(in),
|
||||
vals,
|
||||
axis_size,
|
||||
Limits<AccT>::finite_min());
|
||||
prevmax = maxval;
|
||||
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
|
||||
// Online normalizer calculation for softmax:
|
||||
// https://github.com/NVIDIA/online-softmax
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer = normalizer + softmax_exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// First warp reduce.
|
||||
prevmax = maxval;
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
|
||||
__shared__ AccT local_max[WARP_SIZE];
|
||||
__shared__ AccT local_normalizer[WARP_SIZE];
|
||||
|
||||
// Write to shared memory and do second warp reduce.
|
||||
prevmax = maxval;
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_max[warp.meta_group_rank()] = maxval;
|
||||
}
|
||||
block.sync();
|
||||
maxval = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_max[warp.thread_rank()]
|
||||
: Limits<AccT>::finite_min();
|
||||
maxval = cg::reduce(warp, maxval, max_op);
|
||||
normalizer = normalizer * softmax_exp(prevmax - maxval);
|
||||
if (warp.thread_rank() == 0) {
|
||||
local_normalizer[warp.meta_group_rank()] = normalizer;
|
||||
}
|
||||
block.sync();
|
||||
normalizer = warp.thread_rank() < warp.meta_group_size()
|
||||
? local_normalizer[warp.thread_rank()]
|
||||
: AccT{};
|
||||
normalizer = cg::reduce(warp, normalizer, plus_op);
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Write output.
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
|
||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(index, in, vals, axis_size);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
|
||||
}
|
||||
cub::StoreDirectBlocked(index, out, vals, axis_size);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Softmax::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& s = stream();
|
||||
|
||||
// Make sure that the last dimension is contiguous.
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
array in = set_output(inputs[0]);
|
||||
bool precise = in.dtype() != float32 && precise_;
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
|
||||
using DataType = cuda_type_t<CTYPE>;
|
||||
constexpr int N_READS = 4;
|
||||
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
|
||||
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
|
||||
if (precise) {
|
||||
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
|
||||
}
|
||||
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
|
||||
in.data<DataType>(), out.data<DataType>(), axis_size);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
// Get channel iteration info
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int gemm_k_iters = channel_k_iters;
|
||||
bool align_C = conv_params.C % bk == 0;
|
||||
|
||||
// Fix host side helper params
|
||||
int sign = (conv_params.flip ? -1 : 1);
|
||||
@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_general_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(hash_name, kname, "_alC_", align_C);
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_C, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
auto kernel = get_steel_conv_general_kernel(
|
||||
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
bool out_large =
|
||||
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
|
||||
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
|
||||
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
|
||||
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
|
@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
wn);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
|
@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
||||
|
||||
constant bool align_C [[function_constant(200)]];
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
@ -118,23 +120,58 @@ implicit_gemm_conv_2d_general(
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
int gemm_k_iterations =
|
||||
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||
if (align_C) {
|
||||
int gemm_k_iterations =
|
||||
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
|
||||
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
const short remaining_k = params->C % BK;
|
||||
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
|
||||
// Load elements into threadgroup
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(remaining_k);
|
||||
loader_b.load_safe(remaining_k);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral {
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_safe(const short remaining_k) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
|
||||
// Find bounds
|
||||
int n = read_n[i];
|
||||
|
||||
int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
|
||||
int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
|
||||
|
||||
int ih_dil = read_ih[i] + h_flip * params->kdil[0];
|
||||
int iw_dil = read_iw[i] + w_flip * params->kdil[1];
|
||||
|
||||
int ih = ih_dil / params->idil[0];
|
||||
int iw = iw_dil / params->idil[1];
|
||||
|
||||
size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
|
||||
|
||||
// Read from input if in bounds
|
||||
if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
|
||||
(iw_dil >= 0 && iw < params->iS[1])) {
|
||||
if (bj + vec_size <= remaining_k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
dst[is * dst_ld + j] = (src[i])[offset + j];
|
||||
}
|
||||
} else {
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
if (bj + j < remaining_k) {
|
||||
dst[is * dst_ld + j] = (src[i])[offset + j];
|
||||
} else {
|
||||
dst[is * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Zero pad otherwise
|
||||
else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
dst[is * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
weight_w += jump_params->f_wgt_jump_w;
|
||||
@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral {
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_safe(const short remaining_k) const {
|
||||
const device T* curr_src = src + weight_h * params->wt_strides[1] +
|
||||
weight_w * params->wt_strides[2];
|
||||
|
||||
if ((start_row + BN <= params->O)) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BN; i += TROWS) {
|
||||
if (bj + vec_size <= remaining_k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
}
|
||||
} else {
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
if (bj + j < remaining_k) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
} else {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < BN; i += TROWS) {
|
||||
if ((start_row + i) < params->O) {
|
||||
if (bj + vec_size <= remaining_k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
}
|
||||
} else {
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
if (bj + j < remaining_k) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
} else {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
weight_w += jump_params->f_wgt_jump_w;
|
||||
|
@ -3,8 +3,11 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/fast.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
namespace mlx::core {
|
||||
|
||||
namespace metal {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
@ -19,4 +22,21 @@ device_info() {
|
||||
"[metal::device_info] Cannot get device info without metal backend");
|
||||
};
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
} // namespace metal
|
||||
|
||||
namespace fast {
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int) {
|
||||
return d.get_kernel(kernel_name);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/device.h"
|
||||
|
@ -17,10 +17,7 @@
|
||||
#include "python/src/indexing.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__dlpack_device__",
|
||||
[](const mx::array& a) {
|
||||
// See
|
||||
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
|
||||
if (mx::metal::is_available()) {
|
||||
// Metal device is available
|
||||
return nb::make_tuple(8, 0);
|
||||
} else if (mx::cu::is_available()) {
|
||||
return nb::make_tuple(13, 0);
|
||||
} else {
|
||||
// CPU device
|
||||
return nb::make_tuple(1, 0);
|
||||
|
@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
|
||||
&mx::set_default_device,
|
||||
"device"_a,
|
||||
R"pbdoc(Set the default device.)pbdoc");
|
||||
m.def(
|
||||
"is_available",
|
||||
&mx::is_available,
|
||||
"device"_a,
|
||||
R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
|
||||
def test_dlx_device_type(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
device_type, device_id = a.__dlpack_device__()
|
||||
self.assertIn(device_type, [1, 8])
|
||||
self.assertIn(device_type, [1, 8, 13])
|
||||
self.assertEqual(device_id, 0)
|
||||
|
||||
if device_type == 8:
|
||||
|
@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
|
||||
|
||||
def test_conv2d_unaligned_channels(self):
|
||||
x = mx.random.uniform(shape=(2, 16, 16, 21))
|
||||
w = mx.random.uniform(shape=(32, 3, 3, 21))
|
||||
y = mx.conv2d(x, w, stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
x = mx.random.uniform(shape=(2, 16, 16, 21))
|
||||
w = mx.random.uniform(shape=(21, 3, 3, 21))
|
||||
y = mx.conv2d(x, w, stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -10,7 +10,7 @@ import mlx_tests
|
||||
class TestDefaultDevice(unittest.TestCase):
|
||||
def test_mlx_default_device(self):
|
||||
device = mx.default_device()
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
self.assertEqual(device, mx.Device(mx.gpu))
|
||||
self.assertEqual(str(device), "Device(gpu, 0)")
|
||||
self.assertEqual(device, mx.gpu)
|
||||
@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(s2.device, mx.default_device())
|
||||
self.assertNotEqual(s1, s2)
|
||||
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
s_gpu = mx.default_stream(mx.gpu)
|
||||
self.assertEqual(s_gpu.device, mx.gpu)
|
||||
else:
|
||||
@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
s_cpu = mx.new_stream(mx.cpu)
|
||||
self.assertEqual(s_cpu.device, mx.cpu)
|
||||
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
s_gpu = mx.new_stream(mx.gpu)
|
||||
self.assertEqual(s_gpu.device, mx.gpu)
|
||||
else:
|
||||
@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
|
||||
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
|
||||
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
|
||||
self.assertEqual(a.item(), b.item())
|
||||
s_gpu = mx.new_stream(mx.gpu)
|
||||
|
@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
|
||||
|
||||
|
||||
class TestSchedulers(unittest.TestCase):
|
||||
class TestSchedulers(mlx_tests.MLXTestCase):
|
||||
def test_decay_lr(self):
|
||||
for optim_class in optimizers_dict.values():
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user