Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-28 13:11:26 +08:00)

Compare commits: 557a21e767...ebdd22a8d4 (7 commits)
Commits in this range:
ebdd22a8d4
c371baf53a
ccf78f566c
c9fa68664a
c35f4d089a
8590c0941e
095163b8d1
@@ -212,6 +212,29 @@ jobs:
             METAL_DEBUG_ERROR_MODE=0 \
             python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

+  cuda_build_and_test:
+    machine:
+      image: linux-cuda-12:default
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - run:
+          name: Install Python package
+          command: |
+            sudo apt-get update
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            python -m venv env
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            pip install -e ".[dev]"
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+
   build_release:
     parameters:
       python_version:
@@ -348,6 +371,7 @@ workflows:
          parameters:
            macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test
+     - cuda_build_and_test
      - build_documentation

   build_pypi_release:
@@ -455,6 +479,8 @@ workflows:
            macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test:
          requires: [ hold ]
+     - cuda_build_and_test:
+         requires: [ hold ]
   nightly_build:
     when:
       and:
@@ -1,5 +1,6 @@
 // Copyright © 2023 Apple Inc.

+#include <cstring>
 #include <iostream>
 #include <sstream>

benchmarks/python/conv_unaligned_bench.py (new file, 107 lines)

import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    dtype = "float32"
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
        )
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
@@ -55,6 +55,9 @@ endif()

 if(MLX_BUILD_CUDA)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
 endif()

 if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
@@ -6,21 +6,30 @@
 target_sources(
   mlx
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
mlx/backend/cuda/arg_reduce.cu (new file, 189 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
struct IndexValPair {
  uint32_t index;
  T val;
};

template <typename T>
struct ArgMin {
  constexpr __device__ T init() {
    return Limits<T>::max();
  }

  __device__ IndexValPair<T> operator()(
      const IndexValPair<T>& best,
      const IndexValPair<T>& current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  __device__ IndexValPair<T>
  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename T>
struct ArgMax {
  constexpr __device__ T init() {
    return Limits<T>::min();
  }

  __device__ IndexValPair<T> operator()(
      const IndexValPair<T>& best,
      const IndexValPair<T>& current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  __device__ IndexValPair<T>
  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
    const T* in,
    uint32_t* out,
    size_t size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides in_strides,
    const __grid_constant__ Strides out_strides,
    int32_t ndim,
    int64_t axis_stride,
    int32_t axis_size) {
  auto block = cg::this_thread_block();

  int64_t index = cg::this_grid().block_rank();
  if (index >= size) {
    return;
  }

  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
  int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);

  Op op;
  T init = op.init();
  IndexValPair<T> best{0, init};

  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    T vals[N_READS];
    auto tid = r * BLOCK_DIM + block.thread_index().z;
    cub::LoadDirectBlocked(
        tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
    best = op.reduce_many(best, vals, tid * N_READS);
  }

  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  best = BlockReduceT(temp).Reduce(best, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = best.index;
  }
}

} // namespace cu

void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  auto& s = stream();

  // Prepare the shapes, strides and axis arguments.
  Shape shape = remove_index(in.shape(), axis_);
  Strides in_strides = remove_index(in.strides(), axis_);
  Strides out_strides = out.ndim() == in.ndim()
      ? remove_index(out.strides(), axis_)
      : out.strides();
  int64_t axis_stride = in.strides()[axis_];
  int32_t axis_size = in.shape()[axis_];
  int32_t ndim = shape.size();

  // ArgReduce.
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
      using InType = cuda_type_t<CTYPE>;
      constexpr uint32_t N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
        dim3 block_dims{1, 1, BLOCK_DIM};
        auto kernel = &cu::arg_reduce_general<
            InType,
            cu::ArgMax<InType>,
            BLOCK_DIM,
            N_READS>;
        if (reduce_type_ == ArgReduce::ArgMin) {
          kernel = &cu::arg_reduce_general<
              InType,
              cu::ArgMin<InType>,
              BLOCK_DIM,
              N_READS>;
        }
        kernel<<<num_blocks, block_dims, 0, stream>>>(
            in.data<InType>(),
            out.data<uint32_t>(),
            out.size(),
            const_param(shape),
            const_param(in_strides),
            const_param(out_strides),
            ndim,
            axis_stride,
            axis_size);
      });
    });
  });
}

} // namespace mlx::core
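The kernel above assigns one thread block per output element (the grid is sized from the output with get_2d_grid_dims) and sweeps the reduced axis in chunks of BLOCK_DIM * N_READS strided loads, combining per-thread candidates with a block-wide reduce. For orientation only, a minimal host-side sketch of the same arg-max semantics over a contiguous last axis; argmax_last_axis is a made-up name, not an MLX symbol.

// Host-side reference for arg-max over the last axis of a row-major buffer.
// Assumes `in` holds rows * axis_size contiguous float values.
#include <cstdint>
#include <vector>

std::vector<uint32_t> argmax_last_axis(
    const std::vector<float>& in, size_t rows, size_t axis_size) {
  std::vector<uint32_t> out(rows, 0);
  for (size_t r = 0; r < rows; ++r) {
    const float* row = in.data() + r * axis_size;
    uint32_t best = 0;
    for (uint32_t i = 1; i < axis_size; ++i) {
      // Mirror ArgMax above: keep the earliest index on ties.
      if (row[i] > row[best]) {
        best = i;
      }
    }
    out[r] = best;
  }
  return out;
}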
mlx/backend/cuda/cuda.cpp (new file, 11 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cuda.h"

namespace mlx::core::cu {

bool is_available() {
  return true;
}

} // namespace mlx::core::cu
mlx/backend/cuda/cuda.h (new file, 10 lines)

// Copyright © 2025 Apple Inc.

#pragma once

namespace mlx::core::cu {

/* Check if the CUDA backend is available. */
bool is_available();

} // namespace mlx::core::cu
mlx/backend/cuda/iterators/strided_iterator.cuh (new file, 60 lines)

// Copyright © 2025 Apple Inc.

#pragma once

#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>

namespace mlx::core::cu {

// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
    : public thrust::
          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
 public:
  using super_t =
      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;

  using reference = typename super_t::reference;
  using difference_type = typename super_t::difference_type;

  __host__ __device__ strided_iterator(Iterator it, Stride stride)
      : super_t(it), stride_(stride) {}

  __host__ __device__ Stride stride() const {
    return stride_;
  }

 private:
  friend class thrust::iterator_core_access;

  __host__ __device__ bool equal(const strided_iterator& other) const {
    return this->base() == other.base();
  }

  __host__ __device__ void advance(difference_type n) {
    this->base_reference() += n * stride_;
  }

  __host__ __device__ void increment() {
    this->base_reference() += stride_;
  }

  __host__ __device__ void decrement() {
    this->base_reference() -= stride_;
  }

  __host__ __device__ difference_type
  distance_to(const strided_iterator& other) const {
    const difference_type dist = other.base() - this->base();
    _CCCL_ASSERT(
        dist % stride() == 0,
        "Underlying iterator difference must be divisible by the stride");
    return dist / stride();
  }

  Stride stride_;
};

} // namespace mlx::core::cu
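A usage sketch for the adaptor, assuming the header above is included; sum_strided_column is a made-up kernel, not part of this diff. Wrapping a raw pointer in strided_iterator lets block-wide primitives such as cub::LoadDirectBlocked walk a strided column as if it were contiguous, which is how arg_reduce.cu above and layer_norm.cu below consume it.

// Sketch only: sum one strided column with block-wide loads through the
// adaptor and accumulate the block's partial sums into *out.
#include <cub/block/block_load.cuh>

template <int N_READS = 4>
__global__ void sum_strided_column(
    const float* in, float* out, int64_t stride, int axis_size) {
  float vals[N_READS] = {};
  // Out-of-range slots are filled with 0 so the tail is handled safely.
  cub::LoadDirectBlocked(
      int(threadIdx.x),
      mlx::core::cu::strided_iterator(in, stride),
      vals,
      axis_size,
      0.0f);
  float partial = 0.0f;
  for (int i = 0; i < N_READS; ++i) {
    partial += vals[i];
  }
  atomicAdd(out, partial);
}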
@@ -47,6 +47,31 @@ namespace mlx::core {
     __VA_ARGS__; \
   }

+// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
+#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...)   \
+  {                                                         \
+    uint32_t _num_threads = NUM_THREADS;                    \
+    if (_num_threads <= WARP_SIZE) {                        \
+      constexpr uint32_t BLOCK_DIM = WARP_SIZE;             \
+      __VA_ARGS__;                                          \
+    } else if (_num_threads <= WARP_SIZE * 2) {             \
+      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2;         \
+      __VA_ARGS__;                                          \
+    } else if (_num_threads <= WARP_SIZE * 4) {             \
+      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4;         \
+      __VA_ARGS__;                                          \
+    } else if (_num_threads <= WARP_SIZE * 8) {             \
+      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8;         \
+      __VA_ARGS__;                                          \
+    } else if (_num_threads <= WARP_SIZE * 16) {            \
+      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16;        \
+      __VA_ARGS__;                                          \
+    } else {                                                \
+      constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
+      __VA_ARGS__;                                          \
+    }                                                       \
+  }
+
 // Maps CPU types to CUDA types.
 template <typename T>
 struct CTypeToCudaType {
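The macro maps a runtime thread count to the smallest power-of-two multiple of WARP_SIZE that covers it (capped at WARP_SIZE * WARP_SIZE = 1024) and exposes that value as a constexpr named BLOCK_DIM, so kernels templated on their block size can be instantiated from a runtime size. A minimal dispatch sketch, assuming WARP_SIZE and the macro above are in scope; fill_ones is a made-up kernel, not an MLX symbol.

// Sketch: instantiate a block-size-templated kernel from a runtime count.
template <int BLOCK_DIM>
__global__ void fill_ones(float* out, int n) {
  int i = blockIdx.x * BLOCK_DIM + threadIdx.x;
  if (i < n) {
    out[i] = 1.0f;
  }
}

void launch_fill_ones(float* out, int n, cudaStream_t stream) {
  MLX_SWITCH_BLOCK_DIM(n, BLOCK_DIM, {
    // BLOCK_DIM is a compile-time constant inside this branch.
    int num_blocks = (n + BLOCK_DIM - 1) / BLOCK_DIM;
    fill_ones<BLOCK_DIM><<<num_blocks, BLOCK_DIM, 0, stream>>>(out, n);
  });
}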
@@ -9,6 +9,8 @@
 #pragma once

 #include <cuComplex.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
 #include <cuda/std/array>
 #include <cuda/std/limits>
 #include <cuda/std/tuple>
@@ -19,6 +21,10 @@ namespace mlx::core::cu {
 // CUDA kernel utils
 ///////////////////////////////////////////////////////////////////////////////

+// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
+// warpSize variable exists, using it would prevent compile-time optimizations.
+#define WARP_SIZE 32
+
 // To pass shape/strides to kernels via constant memory, their size must be
 // known at compile time.
 #define MAX_NDIM 8
@@ -26,6 +32,94 @@ namespace mlx::core::cu {
 using Shape = cuda::std::array<int32_t, MAX_NDIM>;
 using Strides = cuda::std::array<int64_t, MAX_NDIM>;

+///////////////////////////////////////////////////////////////////////////////
+// Type limits utils
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename T, typename = void>
+struct Limits {
+  static constexpr __host__ __device__ T max() {
+    return cuda::std::numeric_limits<T>::max();
+  }
+  static constexpr __host__ __device__ T min() {
+    return cuda::std::numeric_limits<T>::min();
+  }
+  static constexpr __host__ __device__ T finite_max() {
+    return cuda::std::numeric_limits<T>::max();
+  }
+  static constexpr __host__ __device__ T finite_min() {
+    return cuda::std::numeric_limits<T>::min();
+  }
+};
+
+template <typename T>
+struct Limits<
+    T,
+    cuda::std::enable_if_t<
+        cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
+  static constexpr __host__ __device__ T max() {
+    return cuda::std::numeric_limits<T>::infinity();
+  }
+  static constexpr __host__ __device__ T min() {
+    return -cuda::std::numeric_limits<T>::infinity();
+  }
+  static constexpr __host__ __device__ T finite_max() {
+    return cuda::std::numeric_limits<T>::max();
+  }
+  static constexpr __host__ __device__ T finite_min() {
+    return cuda::std::numeric_limits<T>::lowest();
+  }
+};
+
+// CUDA 11 does not have host side arithmetic operators for half types.
+template <typename T>
+struct Limits<
+    T,
+    cuda::std::enable_if_t<
+        cuda::std::is_same_v<T, __half> ||
+        cuda::std::is_same_v<T, __nv_bfloat16>>> {
+  static constexpr __host__ __device__ T max() {
+    return cuda::std::numeric_limits<T>::infinity();
+  }
+  static constexpr __host__ __device__ T min() {
+#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
+    return -cuda::std::numeric_limits<T>::infinity();
+#else
+    return -cuda::std::numeric_limits<float>::infinity();
+#endif
+  }
+  static constexpr __host__ __device__ T finite_max() {
+    return cuda::std::numeric_limits<T>::max();
+  }
+  static constexpr __host__ __device__ T finite_min() {
+#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
+    return cuda::std::numeric_limits<T>::lowest();
+#else
+    return cuda::std::numeric_limits<float>::lowest();
+#endif
+  }
+};
+
+template <>
+struct Limits<bool> {
+  static constexpr __host__ __device__ bool max() {
+    return true;
+  }
+  static constexpr __host__ __device__ bool min() {
+    return false;
+  }
+};
+
+template <>
+struct Limits<cuComplex> {
+  static constexpr __host__ __device__ cuComplex max() {
+    return {Limits<float>::max(), Limits<float>::max()};
+  }
+  static constexpr __host__ __device__ cuComplex min() {
+    return {Limits<float>::min(), Limits<float>::min()};
+  }
+};
+
 ///////////////////////////////////////////////////////////////////////////////
 // Indexing utils
 ///////////////////////////////////////////////////////////////////////////////
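These Limits specializations supply reduction identities: for floating types, max() and min() return positive and negative infinity so they act as neutral elements for min/max-style reductions, while finite_max() and finite_min() give the largest and smallest representable finite values (used, for instance, to seed the running maximum in logsumexp below). A small sketch of that use; running_max is not an MLX symbol, and for half types the device comparison assumes a recent compute capability.

// Sketch: seed a running maximum with the type's smallest finite value.
template <typename T>
__device__ T running_max(const T* vals, int n) {
  T best = mlx::core::cu::Limits<T>::finite_min();
  for (int i = 0; i < n; ++i) {
    best = vals[i] > best ? vals[i] : best;
  }
  return best;
}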
@@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   return cuda::std::make_tuple(a_loc, b_loc);
 }

+///////////////////////////////////////////////////////////////////////////////
+// Elem to loc in a loop utils
+///////////////////////////////////////////////////////////////////////////////
+
+template <int DIM, bool General = true, typename OffsetT = size_t>
+struct LoopedElemToLoc {
+  int dim;
+  LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
+  OffsetT offset{0};
+  int index{0};
+
+  __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
+
+  __device__ void next(const int* shape, const int64_t* strides) {
+    if (dim == 0) {
+      return;
+    }
+    index++;
+    offset += OffsetT(strides[dim - 1]);
+    if (index >= shape[dim - 1]) {
+      index = 0;
+      inner_looper.next(shape, strides);
+      offset = inner_looper.offset;
+    }
+  }
+
+  __device__ void next(int n, const int* shape, const int64_t* strides) {
+    if (dim == 0) {
+      return;
+    }
+    index += n;
+    offset += n * OffsetT(strides[dim - 1]);
+
+    if (index >= shape[dim - 1]) {
+      int extra = index - shape[dim - 1];
+      if (extra >= shape[dim - 1]) {
+        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
+        extra = extra % shape[dim - 1];
+      } else {
+        inner_looper.next(shape, strides);
+      }
+      index = 0;
+      offset = inner_looper.offset;
+      if (extra > 0) {
+        next(extra, shape, strides);
+      }
+    }
+  }
+
+  __device__ OffsetT location() {
+    return offset;
+  }
+};
+
+template <typename OffsetT>
+struct LoopedElemToLoc<1, true, OffsetT> {
+  int dim;
+  OffsetT offset{0};
+  int index{0};
+
+  __device__ LoopedElemToLoc(int dim) : dim(dim) {}
+
+  __device__ void next(const int* shape, const int64_t* strides) {
+    index++;
+    if (dim > 1) {
+      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
+    } else {
+      offset += OffsetT(strides[0]);
+    }
+  }
+
+  __device__ void next(int n, const int* shape, const int64_t* strides) {
+    index += n;
+    if (dim > 1) {
+      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
+    } else {
+      offset = index * OffsetT(strides[0]);
+    }
+  }
+
+  __device__ OffsetT location() {
+    return offset;
+  }
+};
+
+template <typename OffsetT>
+struct LoopedElemToLoc<1, false, OffsetT> {
+  OffsetT offset{0};
+
+  __device__ LoopedElemToLoc(int) {}
+
+  __device__ void next(const int*, const int64_t* strides) {
+    offset += OffsetT(strides[0]);
+  }
+
+  __device__ void next(int n, const int*, const int64_t* strides) {
+    offset += n * OffsetT(strides[0]);
+  }
+
+  __device__ OffsetT location() {
+    return offset;
+  }
+};
+
 } // namespace mlx::core::cu
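LoopedElemToLoc behaves like an odometer over the reduction axes: next() bumps the innermost index, adds the matching stride to the running offset, and carries into the inner looper when an axis wraps, so strided elements can be visited with additions instead of a full elem_to_loc per step. A sketch of the intended use; walk_example is not an MLX symbol, and the stated equivalence to elem_to_loc is an assumption drawn from how col_reduce.cu below drives the helper.

// Sketch: visiting every element of a 2-axis reduction in row-major order.
__device__ void walk_example(
    const int* reduce_shape, // e.g. {4, 8}
    const int64_t* reduce_strides,
    int64_t total) { // total = prod(reduce_shape)
  mlx::core::cu::LoopedElemToLoc<2, true, int64_t> loop(2);
  for (int64_t i = 0; i < total; ++i) {
    int64_t offset = loop.location(); // expected to match elem_to_loc(i, ...)
    (void)offset; // use the offset to index the input here
    loop.next(reduce_shape, reduce_strides);
  }
}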
mlx/backend/cuda/layer_norm.cu (new file, 390 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

inline __device__ float3 plus_f3(const float3& a, const float3& b) {
  return {a.x + b.x, a.y + b.y, a.z + b.z};
}

// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
  static_assert(BLOCK_DIM % WARP_SIZE == 0);
  using TempStorage = T[BLOCK_DIM / WARP_SIZE];

  cg::thread_block& block;
  TempStorage& temp;

  template <typename Op>
  __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
    auto warp = cg::tiled_partition<WARP_SIZE>(block);
    T x = cg::reduce(warp, input, op);
    if (warp.thread_rank() == 0) {
      temp[warp.meta_group_rank()] = x;
    }
    block.sync();
    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
                                                    : init_value;
    return cg::reduce(warp, x, op);
  }

  __device__ T Sum(const T& input) {
    return Reduce(input, cg::plus<T>{}, T{});
  }
};

template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm(
    const T* x,
    const T* w,
    const T* b,
    T* out,
    float eps,
    int32_t axis_size,
    int64_t w_stride,
    int64_t b_stride) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduceT::TempStorage temp;

  x += grid.block_rank() * axis_size;
  out += grid.block_rank() * axis_size;

  // Sum.
  float sum = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS] = {};
    cub::LoadDirectBlocked(index, x, xn, axis_size);
    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
  }
  sum = BlockReduceT{block, temp}.Sum(sum);

  // Mean.
  float mean = sum / axis_size;

  // Normalizer.
  float normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS];
    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
    for (int i = 0; i < N_READS; ++i) {
      float t = static_cast<float>(xn[i]) - mean;
      normalizer += t * t;
    }
  }
  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
  normalizer = rsqrt(normalizer / axis_size + eps);

  // Outputs.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS];
    T wn[N_READS];
    T bn[N_READS];
    cub::LoadDirectBlocked(index, x, xn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
    for (int i = 0; i < N_READS; ++i) {
      float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
      xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
    }
    cub::StoreDirectBlocked(index, out, xn, axis_size);
  }
}

template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void layer_norm_vjp(
    const T* x,
    const T* w,
    const T* g,
    T* gx,
    T* gw,
    float eps,
    int32_t axis_size,
    int64_t w_stride) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
  using BlockReduceF3 = BlockBroadcastReduce<float3, BLOCK_DIM>;
  __shared__ union {
    typename BlockReduceF::TempStorage f;
    typename BlockReduceF3::TempStorage f3;
  } temp;

  x += grid.block_rank() * axis_size;
  g += grid.block_rank() * axis_size;
  gx += grid.block_rank() * axis_size;
  gw += grid.block_rank() * axis_size;

  // Sum.
  float sum = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS] = {};
    cub::LoadDirectBlocked(index, x, xn, axis_size);
    sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
  }
  sum = BlockReduceF{block, temp.f}.Sum(sum);

  // Mean.
  float mean = sum / axis_size;

  // Normalizer.
  float3 factors = {};
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    T xn[N_READS];
    T wn[N_READS] = {};
    T gn[N_READS] = {};
    auto index = r * BLOCK_DIM + block.thread_rank();
    cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
    cub::LoadDirectBlocked(index, g, gn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
    for (int i = 0; i < N_READS; i++) {
      float t = static_cast<float>(xn[i]) - mean;
      float wi = wn[i];
      float gi = gn[i];
      float wg = wi * gi;
      factors = plus_f3(factors, {wg, wg * t, t * t});
    }
  }
  factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
  float meanwg = factors.x / axis_size;
  float meanwgxc = factors.y / axis_size;
  float normalizer2 = 1 / (factors.z / axis_size + eps);
  float normalizer = sqrt(normalizer2);

  // Outputs.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS];
    T wn[N_READS];
    T gn[N_READS];
    cub::LoadDirectBlocked(index, x, xn, axis_size);
    cub::LoadDirectBlocked(index, g, gn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
    for (int i = 0; i < N_READS; i++) {
      float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
      float wi = wn[i];
      float gi = gn[i];
      xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2;
      if constexpr (HAS_W) {
        wn[i] = gi * xi;
      }
    }
    cub::StoreDirectBlocked(index, gx, xn, axis_size);
    if constexpr (HAS_W) {
      cub::StoreDirectBlocked(index, gw, wn, axis_size);
    }
  }
}

} // namespace cu

namespace fast {

bool LayerNorm::use_fallback(Stream s) {
  return s.device == Device::cpu;
}

// TODO: There is duplicated code with backend/metal/normalization.cpp
void LayerNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("LayerNorm::eval_gpu");
  auto& s = stream();
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous.
  auto set_output = [&s, &out](const array& x) {
    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
    if (no_copy && x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            allocator::malloc(x.data_size() * x.itemsize()),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
  };

  array o = set_output(inputs[0]);
  const array& x = o.data_shared_ptr() ? o : out;
  const array& w = inputs[1];
  const array& b = inputs[2];

  int32_t axis_size = x.shape().back();
  int32_t n_rows = x.data_size() / axis_size;
  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
  int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr uint32_t N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
            x.data<DataType>(),
            w.data<DataType>(),
            b.data<DataType>(),
            out.data<DataType>(),
            eps_,
            axis_size,
            w_stride,
            b_stride);
      });
    });
  });
}

void LayerNormVJP::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("LayerNormVJP::eval_gpu");
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Ensure row contiguity. We could relax this step by checking that the array
  // is contiguous (no broadcasts or holes) and that the input strides are the
  // same as the cotangent strides but for now this is simpler.
  auto check_input = [&s](const array& x) -> std::pair<array, bool> {
    if (x.flags().row_contiguous) {
      return {x, false};
    }
    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy_gpu(x, x_copy, CopyType::General, s);
    return {x_copy, true};
  };
  bool donate_x = inputs[0].is_donatable();
  bool donate_g = inputs[3].is_donatable();
  auto [x, copied] = check_input(inputs[0]);
  donate_x |= copied;
  const array& w = inputs[1];
  const array& b = inputs[2];
  auto [g, g_copied] = check_input(inputs[3]);
  donate_g |= g_copied;
  array& gx = outputs[0];
  array& gw = outputs[1];
  array& gb = outputs[2];

  // Check whether we had a weight.
  bool has_w = w.ndim() != 0;

  // Allocate space for the outputs.
  bool g_in_gx = false;
  if (donate_x) {
    gx.copy_shared_buffer(x);
  } else if (donate_g) {
    gx.copy_shared_buffer(g);
    g_in_gx = true;
  } else {
    gx.set_data(allocator::malloc(gx.nbytes()));
  }
  if (g_copied && !g_in_gx) {
    encoder.add_temporary(g);
  }

  int32_t axis_size = x.shape().back();
  int32_t n_rows = x.data_size() / axis_size;
  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;

  // Allocate a temporary to store the gradients for w and allocate the output
  // gradient accumulators.
  array gw_temp =
      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
  if (has_w) {
    if (!g_in_gx && donate_g) {
      gw_temp.copy_shared_buffer(g);
    } else {
      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
      encoder.add_temporary(gw_temp);
    }
  }
  gw.set_data(allocator::malloc(gw.nbytes()));
  gb.set_data(allocator::malloc(gb.nbytes()));

  // Finish with the gradient for b in case we had a b.
  if (gb.ndim() == 1 && gb.size() == axis_size) {
    ReductionPlan plan(
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
    col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
  }

  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_input_array(g);
  encoder.set_output_array(gx);
  encoder.set_output_array(gw_temp);
  encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr int N_READS = 4;
      MLX_SWITCH_BOOL(has_w, HAS_W, {
        MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
          auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
          kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
              x.data<DataType>(),
              w.data<DataType>(),
              g.data<DataType>(),
              gx.data<DataType>(),
              gw_temp.data<DataType>(),
              eps_,
              axis_size,
              w_stride);
        });
      });
    });
  });

  if (has_w) {
    ReductionPlan plan(
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
    col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
  }
}

} // namespace fast

} // namespace mlx::core
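For reference, the forward kernel computes a per-row mean and variance in float and then applies w * (x - mean) * rsqrt(var + eps) + b. A compact host-side reference for a single row; layer_norm_row is an illustration only, not an MLX symbol, and it assumes x, w, and b have the same length.

// Host-side reference for one row of the forward pass above (float only).
#include <cmath>
#include <vector>

std::vector<float> layer_norm_row(
    const std::vector<float>& x,
    const std::vector<float>& w,
    const std::vector<float>& b,
    float eps) {
  const size_t n = x.size();
  float mean = 0.0f;
  for (float v : x) {
    mean += v;
  }
  mean /= n;
  float var = 0.0f;
  for (float v : x) {
    var += (v - mean) * (v - mean);
  }
  // Matches the kernel: normalizer = rsqrt(var / axis_size + eps).
  float normalizer = 1.0f / std::sqrt(var / n + eps);
  std::vector<float> out(n);
  for (size_t i = 0; i < n; ++i) {
    out[i] = w[i] * ((x[i] - mean) * normalizer) + b[i];
  }
  return out;
}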
mlx/backend/cuda/logsumexp.cu (new file, 159 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
inline __device__ T softmax_exp(T x) {
  // Softmax doesn't need a high precision exponential because x is going to be
  // in (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
  return __expf(x);
}

template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void logsumexp(const T* in, T* out, int axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  in += grid.block_rank() * axis_size;

  cg::greater<AccT> max_op;
  cg::plus<AccT> plus_op;

  // Thread reduce.
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min();
  AccT normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    AccT vals[N_READS];
    cub::LoadDirectBlocked(
        r * BLOCK_DIM + block.thread_rank(),
        make_cast_iterator<AccT>(in),
        vals,
        axis_size,
        Limits<AccT>::min());
    prevmax = maxval;
    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
    // Online normalizer calculation for softmax:
    // https://github.com/NVIDIA/online-softmax
    normalizer = normalizer * softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer = normalizer + softmax_exp(vals[i] - maxval);
    }
  }

  // First warp reduce.
  prevmax = maxval;
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  normalizer = cg::reduce(warp, normalizer, plus_op);

  __shared__ AccT local_max[WARP_SIZE];
  __shared__ AccT local_normalizer[WARP_SIZE];

  // Write to shared memory and do second warp reduce.
  prevmax = maxval;
  if (warp.thread_rank() == 0) {
    local_max[warp.meta_group_rank()] = maxval;
  }
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
      : Limits<AccT>::finite_min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
    local_normalizer[warp.meta_group_rank()] = normalizer;
  }
  block.sync();
  normalizer = warp.thread_rank() < warp.meta_group_size()
      ? local_normalizer[warp.thread_rank()]
      : AccT{};
  normalizer = cg::reduce(warp, normalizer, plus_op);

  // Write output.
  if (block.thread_rank() == 0) {
    out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
  }
}

} // namespace cu

void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("LogSumExp::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Make sure that the last dimension is contiguous.
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto in = ensure_contiguous(inputs[0]);
  if (in.flags().row_contiguous) {
    out.set_data(allocator::malloc(out.nbytes()));
  } else {
    auto n = in.shape(-1);
    auto flags = in.flags();
    auto strides = in.strides();
    for (auto& s : strides) {
      s /= n;
    }
    bool col_contig = strides[0] == 1;
    for (int i = 1; col_contig && i < strides.size(); ++i) {
      col_contig &=
          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
    }
    flags.col_contiguous = col_contig;
    out.set_data(
        allocator::malloc(in.nbytes() / n),
        in.data_size() / n,
        std::move(strides),
        flags);
  }

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr int N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
            in.data<DataType>(), out.data<DataType>(), axis_size);
      });
    });
  });
}

} // namespace mlx::core
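The kernel keeps a running maximum m and a running normalizer l; whenever a new chunk raises m, l is rescaled by exp(m_old - m_new), and the final result is m + log(l). A single-pass host-side reference of that recurrence; logsumexp_online is an illustration only, not an MLX symbol.

// Host-side reference for the online logsumexp recurrence used above.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

float logsumexp_online(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the partial sum whenever the running max grows.
    normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  return std::isinf(maxval) ? maxval : std::log(normalizer) + maxval;
}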
mlx/backend/cuda/no_cuda.cpp (new file, 11 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cuda.h"

namespace mlx::core::cu {

bool is_available() {
  return false;
}

} // namespace mlx::core::cu
@@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 }

 NO_GPU(ArgPartition)
-NO_GPU(ArgReduce)
 NO_GPU(BlockMaskedMM)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Convolution)
@@ -86,18 +85,15 @@ NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU(Load)
-NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
 NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
-NO_GPU(Reduce)
 NO_GPU(Scan)
 NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
-NO_GPU(Softmax)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
@@ -105,8 +101,6 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)

 namespace fast {
-NO_GPU_USE_FALLBACK(LayerNorm)
-NO_GPU_MULTI(LayerNormVJP)
 NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
 NO_GPU_USE_FALLBACK(RoPE)
mlx/backend/cuda/reduce.cu (new file, 82 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>

#include <cassert>

namespace mlx::core {

void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Reduce::eval_gpu");
  assert(inputs.size() == 1);
  array in = inputs[0];

  // Make sure no identity reductions trickle down here.
  assert(!axes_.empty());
  assert(out.size() != in.size());

  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  // Fill out with init value.
  if (in.size() == 0) {
    encoder.launch_kernel([&](cudaStream_t stream) {
      MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
        MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
          using InType = cuda_type_t<CTYPE>;
          using OutType = cu::ReduceResult<OP, InType>::type;
          thrust::fill_n(
              cu::thrust_policy(stream),
              thrust::device_pointer_cast(out.data<OutType>()),
              out.data_size(),
              cu::ReduceInit<OP, InType>::value());
        });
      });
    });
    return;
  }

  // Reduce.
  ReductionPlan plan = get_reduction_plan(in, axes_);

  // If it is a general reduce then copy the input to a contiguous array and
  // recompute the plan.
  if (plan.type == GeneralReduce) {
    array in_copy(in.shape(), in.dtype(), nullptr, {});
    copy_gpu(in, in_copy, CopyType::General, s);
    encoder.add_temporary(in_copy);
    in = in_copy;
    plan = get_reduction_plan(in, axes_);
  }

  if ((plan.type == ContiguousAllReduce) ||
      (plan.type == ContiguousReduce && plan.shape.size() == 1)) {
    segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
    return;
  }

  if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
    row_reduce(encoder, in, out, reduce_type_, axes_, plan);
    return;
  }

  if (plan.type == ContiguousStridedReduce ||
      plan.type == GeneralStridedReduce) {
    col_reduce(encoder, in, out, reduce_type_, axes_, plan);
    return;
  }

  throw std::runtime_error("No plan reached in reduce.");
}

} // namespace mlx::core
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
278
mlx/backend/cuda/reduce/col_reduce.cu
Normal file
@ -0,0 +1,278 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

struct ColReduceArgs {
  // The size of the contiguous column reduction.
  size_t reduction_size;
  int64_t reduction_stride;

  // Input shape and strides excluding the reduction axes.
  Shape shape;
  Strides strides;
  int ndim;

  // Input shape and strides of the reduction axes (including last dimension).
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;

  // The number of columns we are reducing. Namely prod(reduce_shape).
  size_t non_col_reductions;

  ColReduceArgs(
      const array& in,
      const ReductionPlan& plan,
      const std::vector<int>& axes) {
    assert(!plan.shape.empty());
    reduction_size = plan.shape.back();
    reduction_stride = plan.strides.back();

    int64_t stride_back = 1;
    auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
    while (!shape_vec.empty() && stride_back < reduction_stride) {
      stride_back *= shape_vec.back();
      shape_vec.pop_back();
      strides_vec.pop_back();
    }
    std::tie(shape_vec, strides_vec) =
        collapse_contiguous_dims(shape_vec, strides_vec);
    shape = const_param(shape_vec);
    strides = const_param(strides_vec);
    ndim = shape_vec.size();

    reduce_shape = const_param(plan.shape);
    reduce_strides = const_param(plan.strides);
    reduce_ndim = plan.shape.size();

    non_col_reductions = 1;
    for (int i = 0; i < reduce_ndim - 1; i++) {
      non_col_reductions *= reduce_shape[i];
    }
  }
};

template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void col_reduce_small(
    const T* in,
    U* out,
    const __grid_constant__ ColReduceArgs args) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  int column =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  if (column * N_READS >= args.reduction_stride) {
    return;
  }

  int out_idx = grid.block_rank() / grid.dim_blocks().x;
  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);

  Op op;
  U totals[N_READS];
  for (int i = 0; i < N_READS; i++) {
    totals[i] = ReduceInit<Op, T>::value();
  }

  // Read input to local.
  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
  loop.next(
      block.thread_index().y,
      args.reduce_shape.data(),
      args.reduce_strides.data());
  for (size_t r = block.thread_index().y;
       r < args.non_col_reductions * args.reduction_size;
       r += block.dim_threads().y) {
    U vals[N_READS];
    cub::LoadDirectBlocked(
        column,
        make_cast_iterator<U>(in + loop.location()),
        vals,
        args.reduction_stride,
        ReduceInit<Op, T>::value());
    for (int i = 0; i < N_READS; i++) {
      totals[i] = op(vals[i], totals[i]);
    }
    loop.next(
        block.dim_threads().y,
        args.reduce_shape.data(),
        args.reduce_strides.data());
  }

  // Do block reduce when each column has more than 1 element to reduce.
  if (block.dim_threads().y > 1) {
    __shared__ U shared_vals[32 * 8 * N_READS];
    size_t col =
        block.thread_index().y * block.dim_threads().x + block.thread_index().x;
    for (int i = 0; i < N_READS; i++) {
      shared_vals[col * N_READS + i] = totals[i];
    }
    block.sync();
    if (block.thread_index().y == 0) {
      for (int i = 0; i < N_READS; i++) {
        totals[i] = shared_vals[block.thread_index().x * N_READS + i];
      }
      for (int j = 1; j < block.dim_threads().y; j++) {
        col = j * block.dim_threads().x + block.thread_index().x;
        for (int i = 0; i < N_READS; i++) {
          totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
        }
      }
    }
  }

  // Write result.
  if (block.thread_index().y == 0) {
    cub::StoreDirectBlocked(
        column,
        out + out_idx * args.reduction_stride,
        totals,
        args.reduction_stride);
  }
}

template <
    typename T,
    typename U,
    typename Op,
    int NDIM,
    int BM,
    int BN,
    int N_READS = 4>
__global__ void col_reduce_looped(
    const T* in,
    U* out,
    const __grid_constant__ ColReduceArgs args) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  constexpr int n_warps = BN / N_READS;

  int out_idx = grid.block_rank() / grid.dim_blocks().x;
  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);

  Op op;
  U totals[N_READS];
  for (int i = 0; i < N_READS; i++) {
    totals[i] = ReduceInit<Op, T>::value();
  }

  // Read input to local.
  int r = block.thread_rank() / n_warps;
  int column = block.thread_rank() % n_warps;
  int in_offset = grid.block_index().x * BN;
  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
  loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
  for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
    U vals[N_READS];
    cub::LoadDirectBlocked(
        column,
        make_cast_iterator<U>(in + loop.location() + in_offset),
        vals,
        args.reduction_stride - in_offset,
        ReduceInit<Op, T>::value());
    for (int i = 0; i < N_READS; i++) {
      totals[i] = op(vals[i], totals[i]);
    }
    loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
  }

  // Do warp reduce for each output.
  constexpr int n_outputs = BN / n_warps;
  static_assert(BM == 32 && n_outputs == N_READS);
  __shared__ U shared_vals[BM * BN];
  size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
  for (int i = 0; i < N_READS; i++) {
    shared_vals[col + i] = totals[i];
  }
  block.sync();
  col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
  for (int i = 0; i < n_outputs; i++) {
    totals[i] = cg::reduce(warp, shared_vals[col + i], op);
  }

  // Write result.
  if (warp.thread_rank() == 0) {
    size_t out_offset = grid.block_index().x * BN;
    cub::StoreDirectBlocked(
        warp.meta_group_rank(),
        out + out_idx * args.reduction_stride + out_offset,
        totals,
        args.reduction_stride - out_offset);
  }
}

} // namespace cu

inline auto output_grid_for_col_reduce(
    const array& out,
    const cu::ColReduceArgs& args) {
  auto out_shape = out.shape();
  auto out_strides = out.strides();
  while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
    out_shape.pop_back();
    out_strides.pop_back();
  }
  return get_2d_grid_dims(out_shape, out_strides);
}

void col_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
  cu::ColReduceArgs args(in, plan, axes);

  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      using InType = cuda_type_t<CTYPE>;
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using OutType = cu::ReduceResult<OP, InType>::type;
        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
          constexpr int N_READS = 4;
          dim3 block_dims;
          dim3 num_blocks = output_grid_for_col_reduce(out, args);
          num_blocks.z = num_blocks.y;
          num_blocks.y = num_blocks.x;
          auto kernel =
              cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
          size_t total = args.non_col_reductions * args.reduction_size;
          if (total < 32) {
            size_t stride_blocks =
                cuda::ceil_div(args.reduction_stride, N_READS);
            block_dims.x = std::min(stride_blocks, 32ul);
            block_dims.y = std::min(total, 8ul);
            num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
          } else {
            constexpr int BM = 32;
            constexpr int BN = 32;
            block_dims.x = BM * BN / N_READS;
            num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
            kernel = cu::
                col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
          }
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in.data<InType>(), out.data<OutType>(), args);
        });
      });
    });
  });
}

} // namespace mlx::core
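Aside (illustration only): col_reduce_looped above folds BM rows of the reduction into each pass before the final warp combine; a minimal NumPy sketch of the same accumulation pattern:

import numpy as np

x = np.random.rand(128, 64)  # (total reduction steps, reduction_stride)
BM = 32                      # rows folded per loop iteration, as in the kernel

partial = np.zeros(x.shape[1])
for r0 in range(0, x.shape[0], BM):
    partial += x[r0:r0 + BM].sum(axis=0)  # each pass folds one BM-row strip

assert np.allclose(partial, x.sum(axis=0))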
mlx/backend/cuda/reduce/reduce.cuh (new file, 74 lines)
@ -0,0 +1,74 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Dispatch dynamic ndim to constexpr.
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
  if (ndim == 1) {                              \
    constexpr uint32_t NDIM = 1;                \
    __VA_ARGS__;                                \
  } else if (ndim == 2) {                       \
    constexpr uint32_t NDIM = 2;                \
    __VA_ARGS__;                                \
  } else {                                      \
    constexpr uint32_t NDIM = 5;                \
    __VA_ARGS__;                                \
  }

// Dispatch reduce ops to constexpr.
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...)           \
  if (REDUCE == Reduce::ReduceType::And) {               \
    using OP = cu::And;                                  \
    __VA_ARGS__;                                         \
  } else if (REDUCE == Reduce::ReduceType::Or) {         \
    using OP = cu::Or;                                   \
    __VA_ARGS__;                                         \
  } else if (REDUCE == Reduce::ReduceType::Sum) {        \
    using OP = cu::Sum;                                  \
    __VA_ARGS__;                                         \
  } else if (REDUCE == Reduce::ReduceType::Prod) {       \
    using OP = cu::Prod;                                 \
    __VA_ARGS__;                                         \
  } else if (REDUCE == Reduce::ReduceType::Max) {        \
    using OP = cu::Max;                                  \
    __VA_ARGS__;                                         \
  } else if (REDUCE == Reduce::ReduceType::Min) {        \
    using OP = cu::Min;                                  \
    __VA_ARGS__;                                         \
  } else {                                               \
    throw std::invalid_argument("Unknown reduce type."); \
  }

void segmented_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan);

void row_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan);

void col_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan);

} // namespace mlx::core
mlx/backend/cuda/reduce/reduce_ops.cuh (new file, 144 lines)
@ -0,0 +1,144 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/kernels/utils.cuh"

namespace mlx::core::cu {

// Reduce ops.
struct And {
  __device__ bool operator()(bool a, bool b) {
    return a && b;
  }
};

struct Or {
  __device__ bool operator()(bool a, bool b) {
    return a || b;
  }
};

struct Sum {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a + b;
  }
};

struct Prod {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a * b;
  }
};

struct Min {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a < b ? a : b;
  }
};

struct Max {
  template <typename T>
  __device__ T operator()(T a, T b) {
    return a > b ? a : b;
  }
};

// Traits to get the result type of reduce op.
template <typename Op, typename T>
struct ReduceResult;

template <typename T>
struct ReduceResult<And, T> {
  using type = bool;
};

template <typename T>
struct ReduceResult<Or, T> {
  using type = bool;
};

template <typename T>
struct ReduceResult<Sum, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

template <typename T>
struct ReduceResult<Prod, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

template <typename T>
struct ReduceResult<Min, T> {
  using type = T;
};

template <typename T>
struct ReduceResult<Max, T> {
  using type = T;
};

// Traits to get the init value of reduce op.
template <typename Op, typename T>
struct ReduceInit;

template <typename T>
struct ReduceInit<And, T> {
  static constexpr __host__ __device__ bool value() {
    return true;
  }
};

template <typename T>
struct ReduceInit<Or, T> {
  static constexpr __host__ __device__ bool value() {
    return false;
  }
};

template <typename T>
struct ReduceInit<Sum, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return T{0, 0};
    } else {
      return typename ReduceResult<Sum, T>::type{0};
    }
  }
};

template <typename T>
struct ReduceInit<Prod, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return T{1, 1};
    } else {
      return typename ReduceResult<Prod, T>::type{1};
    }
  }
};

template <typename T>
struct ReduceInit<Min, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::max();
  }
};

template <typename T>
struct ReduceInit<Max, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::min();
  }
};

} // namespace mlx::core::cu
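Aside (illustration only): the traits above pick the identity element and the accumulator type per op; a small Python sketch of the same conventions:

import math
import numpy as np

# Init values, mirroring ReduceInit (Min/Max use the dtype's limits).
reduce_init = {"and": True, "or": False, "sum": 0, "prod": 1,
               "min": math.inf, "max": -math.inf}

# Sum/Prod over integer types of at most 4 bytes accumulate in int32,
# mirroring ReduceResult.
x = np.array([100, 100, 100], dtype=np.int8)
print(x.sum(dtype=np.int32))  # 300, no int8 overflow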
mlx/backend/cuda/reduce/row_reduce.cu (new file, 250 lines)
@ -0,0 +1,250 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

struct RowReduceArgs {
  // The size of the row being reduced, i.e. the size of last dimension.
  int row_size;

  // Input shape and strides excluding the reduction axes.
  Shape shape;
  Strides strides;
  int ndim;

  // Input shape and strides of the reduction axes excluding last dimension.
  Shape reduce_shape;
  Strides reduce_strides;
  int reduce_ndim;

  // The number of rows we are reducing. Namely prod(reduce_shape).
  size_t non_row_reductions;

  RowReduceArgs(
      const array& in,
      const ReductionPlan& plan,
      const std::vector<int>& axes) {
    assert(!plan.shape.empty());
    row_size = plan.shape.back();

    auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
    std::tie(shape_vec, strides_vec) =
        collapse_contiguous_dims(shape_vec, strides_vec);
    shape = const_param(shape_vec);
    strides = const_param(strides_vec);
    ndim = shape_vec.size();

    reduce_shape = const_param(plan.shape);
    reduce_strides = const_param(plan.strides);
    reduce_ndim = plan.shape.size() - 1;

    non_row_reductions = 1;
    for (int i = 0; i < reduce_ndim; i++) {
      non_row_reductions *= reduce_shape[i];
    }
  }
};

template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small(
    const T* in,
    U* out,
    size_t out_size,
    const __grid_constant__ RowReduceArgs args) {
  size_t out_idx = cg::this_grid().thread_rank();
  if (out_idx >= out_size) {
    return;
  }

  Op op;

  U total_val = ReduceInit<Op, T>::value();
  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);

  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);

  for (size_t n = 0; n < args.non_row_reductions; n++) {
    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
      U vals[N_READS];
      cub::LoadDirectBlocked(
          r,
          make_cast_iterator<U>(in + loop.location()),
          vals,
          args.row_size,
          ReduceInit<Op, T>::value());
      total_val = op(total_val, cub::ThreadReduce(vals, op));
    }
    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
  }

  out[out_idx] = total_val;
}

template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
    const T* in,
    U* out,
    size_t out_size,
    const __grid_constant__ RowReduceArgs args) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  size_t out_idx = grid.thread_rank() / WARP_SIZE;
  if (out_idx >= out_size) {
    return;
  }

  Op op;

  U total_val = ReduceInit<Op, T>::value();
  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);

  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);

  for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
       n += WARP_SIZE) {
    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
      U vals[N_READS];
      cub::LoadDirectBlocked(
          r,
          make_cast_iterator<U>(in + loop.location()),
          vals,
          args.row_size,
          ReduceInit<Op, T>::value());
      total_val = op(total_val, cub::ThreadReduce(vals, op));
    }
    loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
  }

  total_val = cg::reduce(warp, total_val, op);

  if (warp.thread_rank() == 0) {
    out[out_idx] = total_val;
  }
}

template <
    typename T,
    typename U,
    typename Op,
    int NDIM,
    int BLOCK_DIM_X,
    int N_READS = 4>
__global__ void row_reduce_looped(
    const T* in,
    U* out,
    size_t out_size,
    const __grid_constant__ RowReduceArgs args) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
  if (out_idx >= out_size) {
    return;
  }

  Op op;

  U total_val = ReduceInit<Op, T>::value();
  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);

  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);

  for (size_t n = 0; n < args.non_row_reductions; n++) {
    for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
         r++) {
      U vals[N_READS];
      cub::LoadDirectBlocked(
          r * BLOCK_DIM_X + block.thread_index().x,
          make_cast_iterator<U>(in + loop.location()),
          vals,
          args.row_size,
          ReduceInit<Op, T>::value());
      total_val = op(total_val, cub::ThreadReduce(vals, op));
    }
    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
  }

  typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  total_val = BlockReduceT(temp).Reduce(total_val, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = total_val;
  }
}

} // namespace cu

void row_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
  cu::RowReduceArgs args(in, plan, axes);

  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      using InType = cuda_type_t<CTYPE>;
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using OutType = cu::ReduceResult<OP, InType>::type;
        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
          constexpr size_t N_READS = 4;
          dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
          dim3 block_dims, num_blocks;
          auto kernel =
              cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
          if (args.row_size <= 64) {
            if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
                (args.non_row_reductions <= 8)) {
              block_dims.x = std::min(out_dims.x, 1024u);
              num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
              num_blocks.y = out_dims.y;
            } else {
              block_dims.x = WARP_SIZE;
              num_blocks.y = out_dims.x;
              num_blocks.z = out_dims.y;
              kernel =
                  cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
            }
          } else {
            size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
            num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
            MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
              num_blocks.y = out_dims.x;
              num_blocks.z = out_dims.y;
              block_dims.x = BLOCK_DIM_X;
              kernel = cu::row_reduce_looped<
                  InType,
                  OutType,
                  OP,
                  NDIM,
                  BLOCK_DIM_X,
                  N_READS>;
            });
          }
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              in.data<InType>(), out.data<OutType>(), out.size(), args);
        });
      });
    });
  });
}

} // namespace mlx::core
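Aside (illustration only): in all three row kernels above, each thread walks its row in N_READS-wide chunks and folds them into a running value; a NumPy sketch of that per-row loop:

import numpy as np

x = np.random.rand(8, 1000)  # (n_rows, row_size)
N_READS = 4

out = np.empty(x.shape[0])
for i, row in enumerate(x):
    acc = 0.0
    # Walk the row in N_READS-wide chunks, like the per-thread
    # cub::LoadDirectBlocked + cub::ThreadReduce loop in the kernels.
    for start in range(0, row.size, N_READS):
        acc += row[start:start + N_READS].sum()
    out[i] = acc

assert np.allclose(out, x.sum(axis=1))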
mlx/backend/cuda/reduce/segmented_reduce.cu (new file, 84 lines)
@ -0,0 +1,84 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <thrust/device_ptr.h>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_segmented_reduce.cuh>

namespace mlx::core {

template <typename... Args>
void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
}

template <typename... Args>
void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
  // Allocate temporary storage.
  size_t size;
  CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
  encoder.add_temporary(temp);
  // Run op.
  CHECK_CUDA_ERROR(
      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
}

struct MultiplyOp {
  int factor;
  __device__ int operator()(int i) {
    return i * factor;
  }
};

void segmented_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using InType = cuda_type_t<CTYPE>;
        using OutType = cu::ReduceResult<OP, InType>::type;
        auto in_iter = cu::make_cast_iterator<OutType>(
            thrust::device_pointer_cast(in.data<InType>()));
        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
        auto init = cu::ReduceInit<OP, InType>::value();

        if (plan.type == ContiguousAllReduce) {
          cub_all_reduce(
              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
        } else if (plan.type == ContiguousReduce) {
          auto offsets = thrust::make_transform_iterator(
              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
          cub_segmented_reduce(
              encoder,
              in_iter,
              out_ptr,
              out.size(),
              offsets,
              offsets + 1,
              OP(),
              init,
              stream);
        } else {
          throw std::runtime_error("Unsupported plan in segmented_reduce.");
        }
      });
    });
  });
}

} // namespace mlx::core
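Aside (illustration only): for the ContiguousReduce case the segment offsets are simply i * plan.shape.back(), which is what the counting iterator plus MultiplyOp produce above; a NumPy sketch of the same segmented reduction:

import numpy as np

segment = 32                      # plan.shape.back() in the code above
x = np.random.rand(6 * segment)

# Offsets as produced by the counting iterator + MultiplyOp: i * segment.
offsets = [i * segment for i in range(x.size // segment + 1)]
out = np.array([x[offsets[i]:offsets[i + 1]].sum()
                for i in range(len(offsets) - 1)])

assert np.allclose(out, x.reshape(-1, segment).sum(axis=1))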
mlx/backend/cuda/softmax.cu (new file, 160 lines)
@ -0,0 +1,160 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
inline __device__ T softmax_exp(T x) {
  // Softmax doesn't need a high precision exponential because x is going to
  // be in (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
  return __expf(x);
}

template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void softmax(const T* in, T* out, int axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  in += grid.block_rank() * axis_size;
  out += grid.block_rank() * axis_size;

  cg::greater<AccT> max_op;
  cg::plus<AccT> plus_op;

  // Thread reduce.
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min();
  AccT normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    AccT vals[N_READS];
    cub::LoadDirectBlocked(
        r * BLOCK_DIM + block.thread_rank(),
        make_cast_iterator<AccT>(in),
        vals,
        axis_size,
        Limits<AccT>::finite_min());
    prevmax = maxval;
    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
    // Online normalizer calculation for softmax:
    // https://github.com/NVIDIA/online-softmax
    normalizer = normalizer * softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer = normalizer + softmax_exp(vals[i] - maxval);
    }
  }

  // First warp reduce.
  prevmax = maxval;
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  normalizer = cg::reduce(warp, normalizer, plus_op);

  __shared__ AccT local_max[WARP_SIZE];
  __shared__ AccT local_normalizer[WARP_SIZE];

  // Write to shared memory and do second warp reduce.
  prevmax = maxval;
  if (warp.thread_rank() == 0) {
    local_max[warp.meta_group_rank()] = maxval;
  }
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
      : Limits<AccT>::finite_min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
    local_normalizer[warp.meta_group_rank()] = normalizer;
  }
  block.sync();
  normalizer = warp.thread_rank() < warp.meta_group_size()
      ? local_normalizer[warp.thread_rank()]
      : AccT{};
  normalizer = cg::reduce(warp, normalizer, plus_op);
  normalizer = 1 / normalizer;

  // Write output.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T vals[N_READS];
    cub::LoadDirectBlocked(index, in, vals, axis_size);
    for (int i = 0; i < N_READS; i++) {
      vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
    }
    cub::StoreDirectBlocked(index, out, vals, axis_size);
  }
}

} // namespace cu

void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Softmax::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();

  // Make sure that the last dimension is contiguous.
  auto set_output = [&s, &out](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            allocator::malloc(x.data_size() * x.itemsize()),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
  };

  array in = set_output(inputs[0]);
  bool precise = in.dtype() != float32 && precise_;

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr int N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
        if (precise) {
          kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
        }
        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
            in.data<DataType>(), out.data<DataType>(), axis_size);
      });
    });
  });
}

} // namespace mlx::core
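Aside (illustration only): the kernel uses the single-pass online-softmax recurrence referenced in its comment, rescaling the running normalizer whenever the running max grows; a NumPy sketch of the same recurrence:

import numpy as np

def online_softmax(row):
    # Single pass: keep a running max and rescale the running normalizer
    # whenever the max grows, as the kernel does per N_READS chunk.
    m, d = -np.inf, 0.0
    for x in row:
        m_new = max(m, x)
        d = d * np.exp(m - m_new) + np.exp(x - m_new)
        m = m_new
    return np.exp(row - m) / d

row = np.random.rand(16).astype(np.float32)
ref = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
assert np.allclose(online_softmax(row), ref, atol=1e-6)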
@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
   // Get channel iteration info
   int channel_k_iters = ((conv_params.C + bk - 1) / bk);
   int gemm_k_iters = channel_k_iters;
+  bool align_C = conv_params.C % bk == 0;

   // Fix host side helper params
   int sign = (conv_params.flip ? -1 : 1);
@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
       /* const int swizzle_log = */ swizzle_log};

   // Determine kernel
-  std::ostringstream kname;
-  kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
-        << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "implicit_gemm_conv_2d_general_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn);
+  std::string hash_name;
+  hash_name.reserve(64);
+  concatenate(hash_name, kname, "_alC_", align_C);
+  metal::MTLFCList func_consts = {
+      {&align_C, MTL::DataType::DataTypeBool, 200},
+  };

   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel =
-      get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
+  auto kernel = get_steel_conv_general_kernel(
+      d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Deduce grid launch dimensions
@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(

   // Direct to winograd conv
   bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
+      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
   bool channels_large = (conv_params.C + conv_params.O) >= 256;
+  bool out_large =
+      (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
   if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
       conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
     return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
   }

-  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
+  else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
     return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
   }

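Aside (illustration only): align_C gates the bounds-checked load path when the channel count is not a multiple of the K block. With the hypothetical values below (21 channels, bk = 16), the last K block holds only 5 valid columns:

# With a 21-channel input and a 16-wide K block, the last block is partial.
C, bk = 21, 16
channel_k_iters = (C + bk - 1) // bk   # 2
align_C = C % bk == 0                  # False -> take the bounds-checked path
remaining_k = C % bk                   # 5 valid columns in the last K block
print(channel_k_iters, align_C, remaining_k)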
@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array& out,
     int bm,
     int bn,
@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
         wn);
     return kernel_source.str();
   });
-  return d.get_kernel(kernel_name, lib);
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }

 MTL::ComputePipelineState* get_fft_kernel(
@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array& out,
     int bm,
     int bn,
@ -2,6 +2,8 @@

 #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"

+constant bool align_C [[function_constant(200)]];
+
 template <
     typename T,
     int BM,
@ -118,6 +120,7 @@ implicit_gemm_conv_2d_general(
   // Prepare threadgroup mma operation
   mma_t mma_op(simd_gid, simd_lid);

+  if (align_C) {
   int gemm_k_iterations =
       base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

@ -136,6 +139,40 @@ implicit_gemm_conv_2d_general(
     loader_a.next();
     loader_b.next();
   }
+  }
+
+  else {
+    for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
+      for (int j = 0; j < base_wh_size * base_ww_size; j++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+    }
+    const short remaining_k = params->C % BK;
+    for (int j = 0; j < base_wh_size * base_ww_size; j++) {
+      // Load elements into threadgroup
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_a.load_safe(remaining_k);
+      loader_b.load_safe(remaining_k);
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
+  }

   threadgroup_barrier(mem_flags::mem_none);

@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral {
     }
   }

+  METAL_FUNC void load_safe(const short remaining_k) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
+      // Find bounds
+      int n = read_n[i];
+
+      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
+      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
+
+      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
+      int iw_dil = read_iw[i] + w_flip * params->kdil[1];
+
+      int ih = ih_dil / params->idil[0];
+      int iw = iw_dil / params->idil[1];
+
+      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
+
+      // Read from input if in bounds
+      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
+          (iw_dil >= 0 && iw < params->iS[1])) {
+        if (bj + vec_size <= remaining_k) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; ++j) {
+            dst[is * dst_ld + j] = (src[i])[offset + j];
+          }
+        } else {
+          for (short j = 0; j < vec_size; ++j) {
+            if (bj + j < remaining_k) {
+              dst[is * dst_ld + j] = (src[i])[offset + j];
+            } else {
+              dst[is * dst_ld + j] = T(0);
+            }
+          }
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }

   /* Iteration helper */
   METAL_FUNC void next() {
     weight_w += jump_params->f_wgt_jump_w;
@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral {
     }
   }

+  METAL_FUNC void load_safe(const short remaining_k) const {
+    const device T* curr_src = src + weight_h * params->wt_strides[1] +
+        weight_w * params->wt_strides[2];
+
+    if ((start_row + BN <= params->O)) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BN; i += TROWS) {
+        if (bj + vec_size <= remaining_k) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+          }
+        } else {
+          for (short j = 0; j < vec_size; j++) {
+            if (bj + j < remaining_k) {
+              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+            } else {
+              dst[i * dst_ld + j] = T(0);
+            }
+          }
+        }
+      }
+    } else {
+      for (short i = 0; i < BN; i += TROWS) {
+        if ((start_row + i) < params->O) {
+          if (bj + vec_size <= remaining_k) {
+            STEEL_PRAGMA_UNROLL
+            for (short j = 0; j < vec_size; j++) {
+              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+            }
+          } else {
+            for (short j = 0; j < vec_size; j++) {
+              if (bj + j < remaining_k) {
+                dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+              } else {
+                dst[i * dst_ld + j] = T(0);
+              }
+            }
+          }
+        } else {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = T(0);
+          }
+        }
+      }
+    }
+  }

   /* Iteration helper */
   METAL_FUNC void next() {
     weight_w += jump_params->f_wgt_jump_w;
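Aside (illustration only): both load_safe methods copy a full vector only when it fits inside the valid K range and otherwise copy element by element with zero fill; a NumPy sketch of that bounds check (the helper name is hypothetical):

import numpy as np

def load_vec_safe(src, start, vec_size, remaining_k):
    # Copy a full vector when it fits inside the valid K range, otherwise
    # copy element by element and zero-fill past the end, as load_safe does.
    out = np.zeros(vec_size, dtype=src.dtype)
    n = vec_size if start + vec_size <= remaining_k else max(0, remaining_k - start)
    out[:n] = src[start:start + n]
    return out

row = np.arange(21, dtype=np.float32)   # 21 unaligned channels
print(load_vec_safe(row, 20, 4, 21))    # [20., 0., 0., 0.]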
@ -3,8 +3,11 @@
 #include <stdexcept>

 #include "mlx/backend/metal/metal.h"
+#include "mlx/fast.h"

-namespace mlx::core::metal {
+namespace mlx::core {
+
+namespace metal {

 bool is_available() {
   return false;
@ -19,4 +22,21 @@ device_info()
       "[metal::device_info] Cannot get device info without metal backend");
 };

-} // namespace mlx::core::metal
+} // namespace metal
+
+namespace fast {
+
+MetalKernelFunction metal_kernel(
+    const std::string&,
+    const std::vector<std::string>&,
+    const std::vector<std::string>&,
+    const std::string&,
+    const std::string&,
+    bool ensure_row_contiguous,
+    bool atomic_outputs) {
+  throw std::runtime_error("[metal_kernel] No GPU back-end.");
+}
+
+} // namespace fast
+
+} // namespace mlx::core
@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array&,
     int,
     int,
     int,
     int,
     int) {
-  return d.get_kernel(kernel_name);
+  return d.get_kernel(kernel_name, hash_name, func_consts);
 }

 MTL::ComputePipelineState* get_fft_kernel(
@ -2,7 +2,6 @@

 #include "mlx/primitives.h"
 #include "mlx/distributed/primitives.h"
-#include "mlx/fast.h"
 #include "mlx/fast_primitives.h"

 #define NO_GPU_MULTI(func) \
@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
-
-MetalKernelFunction metal_kernel(
-    const std::string&,
-    const std::vector<std::string>&,
-    const std::vector<std::string>&,
-    const std::string&,
-    const std::string&,
-    bool ensure_row_contiguous,
-    bool atomic_outputs) {
-  throw std::runtime_error("[metal_kernel] No GPU back-end.");
-}
-
 } // namespace fast

 namespace distributed {
@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/array.h"
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
@ -17,10 +17,7 @@
 #include "python/src/indexing.h"
 #include "python/src/utils.h"

-#include "mlx/device.h"
-#include "mlx/ops.h"
-#include "mlx/transforms.h"
-#include "mlx/utils.h"
+#include "mlx/mlx.h"

 namespace mx = mlx::core;
 namespace nb = nanobind;
@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
       .def(
           "__dlpack_device__",
           [](const mx::array& a) {
+            // See
+            // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
             if (mx::metal::is_available()) {
-              // Metal device is available
               return nb::make_tuple(8, 0);
+            } else if (mx::cu::is_available()) {
+              return nb::make_tuple(13, 0);
             } else {
               // CPU device
               return nb::make_tuple(1, 0);
@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
       &mx::set_default_device,
       "device"_a,
       R"pbdoc(Set the default device.)pbdoc");
+  m.def(
+      "is_available",
+      &mx::is_available,
+      "device"_a,
+      R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
 }
@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
     def test_dlx_device_type(self):
         a = mx.array([1, 2, 3])
         device_type, device_id = a.__dlpack_device__()
-        self.assertIn(device_type, [1, 8])
+        self.assertIn(device_type, [1, 8, 13])
         self.assertEqual(device_id, 0)

         if device_type == 8:
@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):

         self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))

+    def test_conv2d_unaligned_channels(self):
+        x = mx.random.uniform(shape=(2, 16, 16, 21))
+        w = mx.random.uniform(shape=(32, 3, 3, 21))
+        y = mx.conv2d(x, w, stream=mx.cpu)
+        y_hat = mx.conv2d(x, w)
+        self.assertTrue(mx.allclose(y, y_hat))
+
+        x = mx.random.uniform(shape=(2, 16, 16, 21))
+        w = mx.random.uniform(shape=(21, 3, 3, 21))
+        y = mx.conv2d(x, w, stream=mx.cpu)
+        y_hat = mx.conv2d(x, w)
+        self.assertTrue(mx.allclose(y, y_hat))
+

 if __name__ == "__main__":
     unittest.main()
@ -10,7 +10,7 @@ import mlx_tests
 class TestDefaultDevice(unittest.TestCase):
     def test_mlx_default_device(self):
         device = mx.default_device()
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             self.assertEqual(device, mx.Device(mx.gpu))
             self.assertEqual(str(device), "Device(gpu, 0)")
             self.assertEqual(device, mx.gpu)
@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
         self.assertEqual(s2.device, mx.default_device())
         self.assertNotEqual(s1, s2)

-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.default_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
         s_cpu = mx.new_stream(mx.cpu)
         self.assertEqual(s_cpu.device, mx.cpu)

-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.new_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):

         a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))

-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
             self.assertEqual(a.item(), b.item())
             s_gpu = mx.new_stream(mx.gpu)
@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


-class TestSchedulers(unittest.TestCase):
+class TestSchedulers(mlx_tests.MLXTestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
             lr_schedule = opt.step_decay(1e-1, 0.9, 1)