Compare commits


16 Commits

Author SHA1 Message Date
Awni Hannun
da5912e4f2 fix custom metal extension (#2446) 2025-07-31 06:25:36 -07:00
Cheng
daafee676f Fix wrong graph key when using concurrent context (#2447) 2025-07-31 06:01:05 -07:00
Awni Hannun
d32519c8ee fix gemv regression (#2445) 2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249 fix circular reference (#2443) 2025-07-30 09:37:44 -07:00
Angelos Katharopoulos
3bf81ed1bd [CUDA] Quantized refactoring (#2442) 2025-07-30 08:27:20 -07:00
Cheng
2204182bba Make CI faster (#2440) 2025-07-30 02:26:36 -07:00
Cheng
3628e5d497 Use load_vector in arg_reduce (#2439) 2025-07-30 17:40:26 +09:00
Cheng
a0ae49d397 Move arange to its own file (#2438) 2025-07-30 13:05:51 +09:00
Cheng
254476718b Remove the kernel arg from get_launch_args (#2437) 2025-07-30 11:43:02 +09:00
Awni Hannun
3adba92ebe Cuda faster softmax (#2435)
* faster softmax and logsumexp

* faster softmax and logsumexp

* format
2025-07-29 17:18:12 -07:00
Awni Hannun
ef631d63af faster rms norm (#2433) 2025-07-29 13:12:00 -07:00
Cheng
970dbe8e25 Use ccache in CI (#2414)
* Detect ccache

* Use ccache in CI

* Separate cache for different images

* Test both 12.2 and 12.9 for PRs
2025-07-29 08:43:22 +09:00
Awni Hannun
641be9463b Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90

* add more archs
2025-07-28 12:35:15 -07:00
Awni Hannun
ab0e608862 [CUDA] More sizes for gemv (#2429)
* route more to gemv

* route more sizes to custom gemv
2025-07-28 12:35:01 -07:00
Awni Hannun
1588659062 no occupancy query for launch params (#2426) 2025-07-28 09:09:41 -07:00
Awni Hannun
b9e88fb976 [CUDA] Fix segfault on exit (#2424)
* fix cuda segfault on exit

* comment
2025-07-27 08:08:13 -07:00
43 changed files with 880 additions and 635 deletions

View File

@@ -81,23 +81,24 @@ jobs:
  export DEBIAN_FRONTEND=noninteractive
  export NEEDRESTART_MODE=a
  sudo apt-get update
- sudo apt-get upgrade -y
- pip install --upgrade cmake
  sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
  sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+ curl -LsSf https://astral.sh/uv/install.sh | sh
  - run:
  name: Install Python package
  command: |
- pip install -e ".[dev]"
+ uv venv
+ uv pip install cmake
+ uv pip install -e ".[dev]" -v
  - run:
  name: Generate package stubs
  command: |
- echo "stubs"
- pip install typing_extensions
- python setup.py generate_stubs
+ uv pip install typing_extensions
+ uv run --no-project setup.py generate_stubs
  - run:
  name: Run Python tests
  command: |
+ source .venv/bin/activate
  python -m unittest discover python/tests -v
  mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
  mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
@@ -105,6 +106,7 @@ jobs:
  - run:
  name: Build CPP only
  command: |
+ source .venv/bin/activate
  mkdir -p build && cd build
  cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
  make -j `nproc`
@@ -130,33 +132,30 @@ jobs:
  - run:
  name: Install dependencies
  command: |
- brew install python@3.9
- brew install openmpi
- python3.9 -m venv env
- source env/bin/activate
- pip install --upgrade pip
- pip install --upgrade cmake
- pip install nanobind==2.4.0
- pip install numpy
- pip install torch
- pip install tensorflow
- pip install unittest-xml-reporting
+ HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
+ brew install openmpi uv
  - run:
  name: Install Python package
  command: |
- source env/bin/activate
+ uv venv --python 3.9
+ uv pip install \
+ nanobind==2.4.0 \
+ cmake \
+ numpy \
+ torch \
+ tensorflow \
+ unittest-xml-reporting
  DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
- pip install -e . -v
+ uv pip install -e . -v
  - run:
  name: Generate package stubs
  command: |
- source env/bin/activate
- pip install typing_extensions
- python setup.py generate_stubs
+ uv pip install typing_extensions
+ uv run --no-project setup.py generate_stubs
  - run:
  name: Run Python tests
  command: |
- source env/bin/activate
+ source .venv/bin/activate
  LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
  LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
  mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
@@ -165,16 +164,17 @@ jobs:
  - run:
  name: Build example extension
  command: |
- source env/bin/activate
+ source .venv/bin/activate
  cd examples/extensions
- pip install -r requirements.txt
- python setup.py build_ext -j8
+ uv pip install -r requirements.txt
+ uv run --no-project setup.py build_ext --inplace
+ uv run --no-project python test.py
  - store_test_results:
  path: test-results
  - run:
  name: Build CPP only
  command: |
- source env/bin/activate
+ source .venv/bin/activate
  mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
  - run:
  name: Run CPP tests
@@ -183,7 +183,7 @@ jobs:
  - run:
  name: Build small binary
  command: |
- source env/bin/activate
+ source .venv/bin/activate
  cd build/
  cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
  -DBUILD_SHARED_LIBS=ON \
@@ -195,12 +195,13 @@ jobs:
  - run:
  name: Run Python tests with JIT
  command: |
- source env/bin/activate
  CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
- pip install -e . -v
+ uv pip install -e .
  LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
  METAL_DEBUG_ERROR_MODE=0 \
- python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+ uv run --no-project python -m xmlrunner discover \
+ -v python/tests \
+ -o test-results/gpu_jit

  cuda_build_and_test:
  parameters:
@@ -212,22 +213,42 @@ jobs:
  resource_class: gpu.nvidia.small.gen2
  steps:
  - checkout
+ - restore_cache:
+ keys:
+ - cuda-<< parameters.image_date >>-{{ arch }}-
  - run:
- name: Install Python package
+ name: Install dependencies
  command: |
  sudo apt-get update
  sudo apt-get install libcudnn9-dev-cuda-12
  sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- python3 -m venv env
- source env/bin/activate
+ curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
+ sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
+ rm -rf ccache-4.11.3-linux-x86_64
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ - run:
+ name: Install Python package
+ command: |
+ uv venv
  CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
- pip install -e ".[dev]"
+ uv pip install -e ".[dev]" -v
  - run:
  name: Run Python tests
  command: |
- source env/bin/activate
+ source .venv/bin/activate
  LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
  LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+ - run:
+ name: CCache report
+ command: |
+ ccache --show-stats
+ ccache --zero-stats
+ ccache --max-size 400MB
+ ccache --cleanup
+ - save_cache:
+ key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
+ paths:
+ - /home/circleci/.cache/ccache

  build_release:
  parameters:
@@ -323,14 +344,10 @@ jobs:
  export DEBIAN_FRONTEND=noninteractive
  export NEEDRESTART_MODE=a
  sudo apt-get update
- sudo apt-get upgrade -y
  TZ=Etc/UTC sudo apt-get -y install tzdata
- sudo apt-get install -y apt-utils
- sudo apt-get install -y software-properties-common
  sudo add-apt-repository -y ppa:deadsnakes/ppa
  sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
  sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
- sudo apt-get install -y build-essential git
  $PYTHON -m venv env
  source env/bin/activate
  pip install --upgrade pip
@@ -555,6 +572,9 @@ workflows:
  requires: [ hold ]
  - cuda_build_and_test:
  requires: [ hold ]
+ matrix:
+ parameters:
+ image_date: ["2023.11.1", "2025.05.1"]
  nightly_build:
  when:
  and:

View File

@@ -41,6 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
+option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 # --------------------- Processor tests -------------------------
@@ -68,6 +69,15 @@ else()
 set(MLX_BUILD_METAL OFF)
 endif()

+if(MLX_USE_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  endif()
+endif()
+
 # ----------------------------- Lib -----------------------------
 include(FetchContent)

View File

@@ -394,14 +394,14 @@ below.
 out.set_data(allocator::malloc(out.nbytes()));

 // Resolve name of kernel
-std::ostringstream kname;
-kname << "axpby_" << "general_" << type_to_name(out);
+std::string kname;
+kname = "axpby_general_" + type_to_name(out);

 // Load the metal library
-auto lib = d.get_library("mlx_ext");
+auto lib = d.get_library("mlx_ext", current_binary_dir());

 // Make a kernel from this metal library
-auto kernel = d.get_kernel(kname.str(), lib);
+auto kernel = d.get_kernel(kname, lib);

 // Prepare to encode kernel
 auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -1,5 +1,6 @@
 // Copyright © 2023-2025 Apple Inc.

+#include <dlfcn.h>
 #include <iostream>
 #include <sstream>

@@ -16,6 +17,19 @@
 namespace my_ext {

+// A helper function to find the location of the current binary on disk.
+// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
+std::string current_binary_dir() {
+  static std::string binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path().string();
+  }();
+  return binary_dir;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
 ///////////////////////////////////////////////////////////////////////////////
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
 }

 // Resolve name of kernel (corresponds to axpby.metal)
-std::ostringstream kname;
-kname << "axpby_";
-kname << (contiguous_kernel ? "contiguous_" : "general_");
-kname << type_to_name(out);
+std::string kname = "axpby_";
+kname += (contiguous_kernel ? "contiguous_" : "general_");
+kname += type_to_name(out);

 // Load the metal library
-auto lib = d.get_library("mlx_ext");
+auto lib = d.get_library("mlx_ext", current_binary_dir());

 // Make a kernel from this metal library
-auto kernel = d.get_kernel(kname.str(), lib);
+auto kernel = d.get_kernel(kname, lib);

 // Prepare to encode kernel
 auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.2.0
+nanobind==2.4.0

View File

@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
 a = mx.ones((3, 4))
 b = mx.ones((3, 4))
-c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)

-print(f"c shape: {c.shape}")
-print(f"c dtype: {c.dtype}")
-print(f"c correct: {mx.all(c == 6.0).item()}")
+print(f"c shape: {c_cpu.shape}")
+print(f"c dtype: {c_cpu.dtype}")
+print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
+print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")

View File

@@ -6,6 +6,7 @@
 target_sources(
 mlx
 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
@@ -29,7 +30,7 @@ target_sources(
 ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
-${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
+${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
@@ -45,7 +46,8 @@ target_sources(
 ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
+${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
 ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
@@ -105,11 +107,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
 mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
 endif()

-# Compute capability 7 is required for synchronization between CPU/GPU with
-# managed memory. TODO: Add more architectures for potential performance gain.
-set(MLX_CUDA_ARCHITECTURES
-  "70;80"
-  CACHE STRING "CUDA architectures")
+# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
+# managed memory.
+if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
+  set(MLX_CUDA_ARCHITECTURES "native")
+endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
 "${MLX_CUDA_ARCHITECTURES}")

View File

@@ -0,0 +1,55 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace cu
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}
} // namespace mlx::core
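For reference, a standalone sketch of the same Thrust pattern the new arange kernel uses: a counting iterator transformed through a start/step functor directly into device memory. Everything here (the ArangeOp functor, the raw cudaMalloc buffer, main) is illustrative scaffolding, not MLX code.

// Minimal Thrust counting-iterator + transform sketch (not MLX code).
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

#include <cstdio>

struct ArangeOp {
  float start;
  float step;
  __host__ __device__ float operator()(unsigned int i) const {
    return start + i * step;
  }
};

int main() {
  const int n = 8;
  float* d_out = nullptr;
  cudaMalloc(&d_out, n * sizeof(float));

  // Map each index i in [0, n) to start + i * step, writing straight to d_out.
  thrust::transform(
      thrust::device,
      thrust::counting_iterator<unsigned int>(0),
      thrust::counting_iterator<unsigned int>(n),
      thrust::device_pointer_cast(d_out),
      ArangeOp{1.0f, 0.5f});

  float h_out[n];
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    std::printf("%g ", h_out[i]); // 1 1.5 2 2.5 ...
  }
  std::printf("\n");
  cudaFree(d_out);
  return 0;
}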

View File

@@ -44,8 +44,11 @@ struct ArgMin {
   }

   template <int N>
-  __device__ IndexValPair<T>
-  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
+  __device__ IndexValPair<T> reduce_many(
+      IndexValPair<T> best,
+      const AlignedVector<T, N>& vals,
+      uint32_t offset) {
+#pragma unroll
     for (int i = 0; i < N; i++) {
       if (vals[i] < best.val) {
         best.val = vals[i];
@@ -74,8 +77,11 @@ struct ArgMax {
   }

   template <int N>
-  __device__ IndexValPair<T>
-  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
+  __device__ IndexValPair<T> reduce_many(
+      IndexValPair<T> best,
+      const AlignedVector<T, N>& vals,
+      uint32_t offset) {
+#pragma unroll
     for (int i = 0; i < N; i++) {
       if (vals[i] > best.val) {
         best.val = vals[i];
@@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
   int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
   int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
+  in += in_idx;

   Op op;
   T init = op.init();
   IndexValPair<T> best{0, init};

   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T vals[N_READS];
     auto tid = r * BLOCK_DIM + block.thread_index().x;
-    cub::LoadDirectBlocked(
-        tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
+    auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
     best = op.reduce_many(best, vals, tid * N_READS);
   }

View File

@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b[0]); out_vec[i] = Op{}(a[0], b[0]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b_vec.val[i]); out_vec[i] = Op{}(a[0], b_vec[i]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b[0]); out_vec[i] = Op{}(a_vec[i], b[0]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -92,7 +92,7 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]); out_vec[i] = Op{}(a_vec[i], b_vec[i]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -211,12 +211,15 @@ void binary_op_gpu_inplace(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large()); get_launch_args(out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<InType>(), a.data<InType>(),
@@ -228,11 +231,9 @@ void binary_op_gpu_inplace(
const_param<dims_constant()>(b_strides)); const_param<dims_constant()>(b_strides));
}); });
} else { } else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>; auto [num_blocks, block_dims] = get_launch_args(out, large());
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_g<Op, InType, OutType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<InType>(), a.data<InType>(),
@@ -248,8 +249,7 @@ void binary_op_gpu_inplace(
} else { } else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. constexpr int N_READS = 16 / sizeof(InType);
constexpr int N_READS = 4;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>; auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -259,12 +259,7 @@ void binary_op_gpu_inplace(
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large(), N_READS);
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -33,8 +33,8 @@ binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]); auto out = Op{}(a[0], b[0]);
out_a_vec.val[i] = out[0]; out_a_vec[i] = out[0];
out_b_vec.val[i] = out[1]; out_b_vec[i] = out[1];
} }
store_vector<N_READS>(out_a, index, out_a_vec); store_vector<N_READS>(out_a, index, out_a_vec);
@@ -60,9 +60,9 @@ binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec; AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec.val[i]); auto out = Op{}(a[0], b_vec[i]);
out_a_vec.val[i] = out[0]; out_a_vec[i] = out[0];
out_b_vec.val[i] = out[1]; out_b_vec[i] = out[1];
} }
store_vector<N_READS>(out_a, index, out_a_vec); store_vector<N_READS>(out_a, index, out_a_vec);
@@ -88,9 +88,9 @@ binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec; AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b[0]); auto out = Op{}(a_vec[i], b[0]);
out_a_vec.val[i] = out[0]; out_a_vec[i] = out[0];
out_b_vec.val[i] = out[1]; out_b_vec[i] = out[1];
} }
store_vector<N_READS>(out_a, index, out_a_vec); store_vector<N_READS>(out_a, index, out_a_vec);
@@ -117,9 +117,9 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec; AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b_vec.val[i]); auto out = Op{}(a_vec[i], b_vec[i]);
out_a_vec.val[i] = out[0]; out_a_vec[i] = out[0];
out_b_vec.val[i] = out[1]; out_b_vec[i] = out[1];
} }
store_vector<N_READS>(out_a, index, out_a_vec); store_vector<N_READS>(out_a, index, out_a_vec);
@@ -227,16 +227,15 @@ void binary_two_op_gpu_inplace(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large()); get_launch_args(out_a, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<InType>(), a.data<InType>(),
@@ -249,11 +248,10 @@ void binary_two_op_gpu_inplace(
const_param<dims_constant()>(b_strides)); const_param<dims_constant()>(b_strides));
}); });
} else { } else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large()); get_launch_args(out_a, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::binary_two_g<Op, InType, OutType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<InType>(), a.data<InType>(),
@@ -270,8 +268,7 @@ void binary_two_op_gpu_inplace(
} else { } else {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) { dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. constexpr int N_READS = 16 / sizeof(InType);
constexpr int N_READS = 4;
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>; auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -281,7 +278,6 @@ void binary_two_op_gpu_inplace(
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>; kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(), out_a.data_size(),
out_a.shape(), out_a.shape(),
out_a.strides(), out_a.strides(),

View File

@@ -294,7 +294,7 @@ void Compiled::eval_gpu(
 auto kernel = mod.get_kernel(kernel_name);
 auto [num_blocks, block_dims] =
-    get_launch_args(kernel, outputs[0], large, work_per_thread);
+    get_launch_args(outputs[0], large, work_per_thread);
 encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }

View File

@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in[0]); out_vec[i] = cast_to<Out>(in[0]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in_vec.val[i]); out_vec[i] = cast_to<Out>(in_vec[i]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -65,19 +65,13 @@ void copy_contiguous(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. constexpr int N_READS = 16 / sizeof(InType);
constexpr int N_READS = 4;
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>; auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>; kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large(), N_READS);
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -71,12 +71,10 @@ void copy_general(
data_size *= s; data_size *= s;
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) { dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto kernel = auto [num_blocks, block_dims] =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>; get_launch_args(data_size, shape, out.strides(), large());
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
num_blocks, num_blocks,
block_dims, block_dims,
in_ptr, in_ptr,
@@ -87,11 +85,10 @@ void copy_general(
const_param<ndim_constant()>(strides_out)); const_param<ndim_constant()>(strides_out));
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>; auto [num_blocks, block_dims] =
auto [num_blocks, block_dims] = get_launch_args( get_launch_args(data_size, shape, out.strides(), large());
kernel, data_size, shape, out.strides(), large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::copy_gg<InType, OutType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
in_ptr, in_ptr,

View File

@@ -74,12 +74,13 @@ void copy_general_dynamic(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu:: auto [num_blocks, block_dims] = get_launch_args(out, large());
copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::copy_gg_dynamic_nd<
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks, num_blocks,
block_dims, block_dims,
in_ptr, in_ptr,
@@ -92,11 +93,9 @@ void copy_general_dynamic(
dynamic_offset_out.data<int64_t>()); dynamic_offset_out.data<int64_t>());
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>; auto [num_blocks, block_dims] = get_launch_args(out, large());
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::copy_gg_dynamic<InType, OutType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
in_ptr, in_ptr,

View File

@@ -63,12 +63,9 @@ void copy_general_input(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto [num_blocks, block_dims] = get_launch_args(out, large());
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
num_blocks, num_blocks,
block_dims, block_dims,
in_ptr, in_ptr,
@@ -78,11 +75,9 @@ void copy_general_input(
const_param<dims_constant()>(strides_in)); const_param<dims_constant()>(strides_in));
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>; auto [num_blocks, block_dims] = get_launch_args(out, large());
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::copy_g<InType, OutType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
in_ptr, in_ptr,

View File

@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"

@@ -54,6 +55,10 @@ Device::Device(int device) : device_(device) {
 CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
 // The cudnn handle is used by Convolution.
 CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
+
+  // Initialize the jit module cache here ensures it is not
+  // unloaded before any evaluation is done
+  get_jit_module_cache();
 }

 Device::~Device() {
@@ -92,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
 if (discard) {
 return;
 }

-  // Extract and add as single kernel node when possible.
-  size_t num_nodes;
-  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
-  if (num_nodes == 1) {
-    cudaGraphNode_t captured_node;
-    CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
-    cudaGraphNodeType type;
-    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
-    if (type == cudaGraphNodeTypeKernel) {
-      CUDA_KERNEL_NODE_PARAMS params;
-      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
-      enc.add_kernel_node(params);
-      return;
-    }
-  }
-  // Otherwise add the captured graph as subgraph.
 enc.add_graph_node(graph);
 }
@@ -330,6 +318,7 @@ void CommandEncoder::commit() {
 // Reset state
 node_count_ = 0;
 graph_node_count_ = 0;
+empty_node_count_ = 0;
 from_nodes_.clear();
 to_nodes_.clear();
 graph_key_.clear();
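The constructor change above warms the JIT module cache so that it outlives later work. A minimal, self-contained C++ sketch (not MLX code, standard C++ only) of the rule this leans on: objects with static storage duration are destroyed in reverse order of construction, so a function-local static cache that is touched early is torn down late.

#include <cstdio>
#include <string>
#include <unordered_map>

struct Noisy {
  std::string name;
  explicit Noisy(std::string n) : name(std::move(n)) {
    std::printf("construct %s\n", name.c_str());
  }
  ~Noisy() {
    std::printf("destroy %s\n", name.c_str());
  }
};

// Function-local static: constructed on first use, destroyed after every
// static object whose construction completed later than this one.
std::unordered_map<std::string, Noisy>& cache() {
  static std::unordered_map<std::string, Noisy> map;
  return map;
}

struct Device {
  Device() {
    cache(); // touch the cache first so it is destroyed after this Device
    std::printf("construct Device\n");
  }
  ~Device() {
    std::printf("destroy Device\n");
  }
};

int main() {
  static Device dev;                       // "construct Device"
  cache().try_emplace("module", "jit-module"); // "construct jit-module"
  // At exit: "destroy Device" first, then "destroy jit-module".
  // If cache() were first touched only after Device finished constructing,
  // the cache would be destroyed before the Device instead.
  return 0;
}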

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
 }

 inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
-#if __CUDA_ARCH__ < 900
   atomic_add_general(out, val);
-#else
-  atomicAdd(out, val);
-#endif
 }

 inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {

View File

@@ -32,36 +32,119 @@ using Strides = cuda::std::array<int64_t, MAX_NDIM>;
template <typename T, int N> template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector { struct alignas(sizeof(T) * N) AlignedVector {
T val[N]; T val[N];
__device__ T& operator[](int i) {
return val[i];
}
__device__ T operator[](int i) const {
return val[i];
}
}; };
template <int N, typename T> template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector( inline __host__ __device__ bool is_aligned(T* x) {
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
const T* ptr, const T* ptr,
uint32_t offset) { uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr); auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset]; return from[offset];
} }
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
if (is_aligned<N>(ptr)) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = ptr[offset * N + i];
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset,
SizeT size,
int64_t stride,
T fallback) {
if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] =
(N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
}
return v;
}
}
template <int N, typename T> template <int N, typename T>
inline __device__ void inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) { unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr); auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec; to[offset] = vec;
} }
// Helper for accessing strided data. template <int N, typename T>
template <typename T> inline __device__ void
struct StridedIterator { store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
T it; if (is_aligned<N>(ptr)) {
int64_t stride; auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
__host__ __device__ StridedIterator(T it, int64_t stride) } else {
: it(it), stride(stride) {} #pragma unroll
for (int i = 0; i < N; ++i) {
__host__ __device__ auto operator[](int i) const { ptr[offset * N + i] = vec[i];
return it[i * stride]; }
} }
}; }
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
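A self-contained CUDA sketch of how the new helpers above are meant to be used together, with simplified signatures that mirror (but are not) the MLX header: vectorized loads and stores when the pointer is aligned and the read stays in bounds, and a scalar fallback with a fill value near the tail. The scale kernel is an illustrative example, not an MLX kernel.

#include <cstdint>
#include <cuda_runtime.h>

template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
  __device__ T& operator[](int i) { return val[i]; }
  __device__ T operator[](int i) const { return val[i]; }
};

template <int N, typename T>
__device__ bool is_aligned(const T* x) {
  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}

// Load N contiguous values starting at ptr + offset * N, padding with
// `fallback` past `size`.
template <int N, typename T>
__device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t offset, uint32_t size, T fallback) {
  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
    return reinterpret_cast<const AlignedVector<T, N>*>(ptr)[offset];
  }
  AlignedVector<T, N> v;
  for (int i = 0; i < N; ++i) {
    v[i] = (offset * N + i) < size ? ptr[offset * N + i] : fallback;
  }
  return v;
}

template <int N, typename T>
__device__ void store_vector(
    T* ptr, uint32_t offset, const AlignedVector<T, N>& v, uint32_t size) {
  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
    reinterpret_cast<AlignedVector<T, N>*>(ptr)[offset] = v;
    return;
  }
  for (int i = 0; i < N && (offset * N + i) < size; ++i) {
    ptr[offset * N + i] = v[i];
  }
}

// Scale a float buffer, 4 elements per thread; the helpers handle the tail.
__global__ void scale(const float* in, float* out, uint32_t size, float a) {
  constexpr int N = 4;
  uint32_t offset = blockIdx.x * blockDim.x + threadIdx.x;
  if (offset * N >= size) {
    return;
  }
  auto v = load_vector<N>(in, offset, size, 0.0f);
  for (int i = 0; i < N; ++i) {
    v[i] *= a;
  }
  store_vector<N>(out, offset, v, size);
}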

View File

@@ -11,7 +11,6 @@ namespace mlx::core::cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
static constexpr int n_per_thread = 4;
static constexpr int rows_per_block = 8; static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread> template <typename T, int rows_per_block, int n_per_thread>
@@ -28,12 +27,13 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
float sum = 0.0f; float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols; for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) { col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0); auto local_mat =
auto local_vec = load_vector<n_per_thread>(vec + col, 0); unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll #pragma unroll
for (int j = 0; j < n_per_thread; ++j) { for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<float>(local_mat.val[j]) * sum +=
static_cast<float>(local_vec.val[j]); static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
} }
} }
@@ -74,8 +74,22 @@ __global__ void gemv_batched(
} }
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) { bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % (WARP_SIZE * n_per_thread) == 0 && return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
((M == 1 && b_transposed) || (N == 1 && !a_transposed)); }
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
switch (n_per_thread) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
}
} }
void gemv( void gemv(
@@ -114,33 +128,43 @@ void gemv(
rows = M; rows = M;
} }
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
if (batch_count == 1) { int n_per_t;
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>; if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
encoder.add_kernel_node( n_per_t = 4;
kernel, } else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
num_blocks_x, n_per_t = 2;
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else { } else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>; n_per_t = 1;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
} }
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
}); });
} }
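A minimal host-only sketch of the dispatch_n_per_thread pattern introduced above, assuming nothing from MLX: the runtime vector width is mapped onto a std::integral_constant so the kernel template sees a compile-time constant. launch_gemv_stub is a hypothetical stand-in for the real kernel launch.

#include <cstdio>
#include <type_traits>

template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
  switch (n_per_thread) {
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 4: f(std::integral_constant<int, 4>{}); break;
  }
}

template <int NPerThread>
void launch_gemv_stub(int K) {
  // In the real code this would instantiate and enqueue a kernel such as
  // gemv_single<T, rows_per_block, NPerThread>.
  std::printf("K=%d -> n_per_thread=%d\n", K, NPerThread);
}

int main() {
  for (int K : {96, 128, 160, 256}) {
    // Mirrors the selection logic above (pointer alignment checks omitted).
    int n = (K % 128 == 0) ? 4 : (K % 64 == 0) ? 2 : 1;
    dispatch_n_per_thread(n, [&](auto n_per_thread) {
      launch_gemv_stub<n_per_thread()>(K);
    });
  }
  return 0;
}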

View File

@@ -128,7 +128,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 encoder.set_output_array(out);
 auto kernel = mod.get_kernel(kernel_name);
-auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+auto [num_blocks, block_dims] = get_launch_args(out, large);
 encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
@@ -229,7 +229,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 encoder.set_output_array(out);
 auto kernel = mod.get_kernel(kernel_name);
-auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
+auto [num_blocks, block_dims] = get_launch_args(upd, large);
 encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
@@ -317,7 +317,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 encoder.set_output_array(out);
 auto kernel = mod.get_kernel(kernel_name);
-auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+auto [num_blocks, block_dims] = get_launch_args(idx, large);
 encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
@@ -421,7 +421,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 encoder.set_output_array(out);
 auto kernel = mod.get_kernel(kernel_name);
-auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+auto [num_blocks, block_dims] = get_launch_args(idx, large);
 encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }

View File

@@ -9,7 +9,6 @@
 #include <cstdlib>
 #include <filesystem>
 #include <fstream>
-#include <unordered_map>

 #include <fmt/format.h>
 #include <nvrtc.h>
@@ -330,11 +329,16 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
 return it->second;
 }

+std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
+  static std::unordered_map<std::string, JitModule> map;
+  return map;
+}
+
 JitModule& get_jit_module(
 const mlx::core::Device& device,
 const std::string& name,
 const KernelBuilder& builder) {
-  static std::unordered_map<std::string, JitModule> map;
+  auto& map = get_jit_module_cache();
 auto it = map.find(name);
 if (it == map.end()) {
 it = map.try_emplace(name, cu::device(device), name, builder).first;

View File

@@ -99,6 +99,8 @@ class JitModule {
 std::unordered_map<std::string, CUfunction> kernels_;
 };

+std::unordered_map<std::string, JitModule>& get_jit_module_cache();
+
 JitModule& get_jit_module(
 const mlx::core::Device& device,
 const std::string& name,

View File

@@ -30,4 +30,25 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
 return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
 }

+std::tuple<dim3, uint> get_launch_args(
+    size_t size,
+    const Shape& shape,
+    const Strides& strides,
+    bool large,
+    int work_per_thread) {
+  size_t nthreads = cuda::ceil_div(size, work_per_thread);
+  uint block_dim = 1024;
+  if (block_dim > nthreads) {
+    block_dim = nthreads;
+  }
+  dim3 num_blocks;
+  if (large) {
+    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
+    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
+  } else {
+    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
+  }
+  return std::make_tuple(num_blocks, block_dim);
+}
+
 } // namespace mlx::core
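A host-only sketch of the arithmetic the relocated get_launch_args performs now that the occupancy query is gone: a fixed 1024-thread block and a grid sized by ceil-dividing the element count by work_per_thread. The 2D "large" grid path and the dim3 type are omitted; launch_args_1d is an illustrative name, not an MLX function.

#include <cstdio>

static size_t ceil_div(size_t a, size_t b) {
  return (a + b - 1) / b;
}

static void launch_args_1d(
    size_t size,
    int work_per_thread,
    unsigned* num_blocks,
    unsigned* block_dim) {
  size_t nthreads = ceil_div(size, work_per_thread);
  // Cap the block at 1024 threads, shrink it for tiny launches.
  *block_dim = nthreads < 1024 ? static_cast<unsigned>(nthreads) : 1024u;
  *num_blocks = static_cast<unsigned>(ceil_div(nthreads, *block_dim));
}

int main() {
  unsigned blocks, threads;
  launch_args_1d(1 << 20, /*work_per_thread=*/4, &blocks, &threads);
  std::printf("%u blocks x %u threads\n", blocks, threads); // 256 x 1024
  return 0;
}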

View File

@@ -120,53 +120,19 @@ dim3 get_2d_grid_dims(
 size_t divisor);

 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

-// Return a block size that achieves maximum potential occupancy for kernel.
-template <typename T>
-inline uint max_occupancy_block_dim(T kernel) {
-  int _, block_dim;
-  if constexpr (std::is_same_v<T, CUfunction>) {
-    CHECK_CUDA_ERROR(
-        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
-  } else {
-    CHECK_CUDA_ERROR(
-        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
-  }
-  return block_dim;
-}
-
 // Get the num_blocks and block_dims that maximize occupancy for |kernel|,
 // assuming each thread handles |work_per_thread| elements of |arr|.
-template <typename T>
-inline std::tuple<dim3, uint> get_launch_args(
-    T kernel,
+std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1) {
-  size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = max_occupancy_block_dim(kernel);
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
-  dim3 num_blocks;
-  if (large) {
-    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
-    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
-  } else {
-    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
-  }
-  return std::make_tuple(num_blocks, block_dim);
-}
+    int work_per_thread = 1);

-template <typename T>
-inline std::tuple<dim3, uint> get_launch_args(
-    T kernel,
-    const array& arr,
-    bool large,
-    int work_per_thread = 1) {
+inline std::tuple<dim3, uint>
+get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
   return get_launch_args(
-      kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
 }

 } // namespace mlx::core

View File

@@ -10,8 +10,6 @@
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core { namespace mlx::core {
@@ -74,9 +72,11 @@ __global__ void layer_norm(
float sum = 0; float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {}; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size); #pragma unroll
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{})); for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
} }
sum = BlockReduceT{block, temp}.Sum(sum); sum = BlockReduceT{block, temp}.Sum(sum);
@@ -87,11 +87,18 @@ __global__ void layer_norm(
float normalizer = 0; float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; if ((index + 1) * N_READS <= axis_size) {
cub::LoadDirectBlocked(index, x, xn, axis_size, mean); auto xn = load_vector<N_READS>(x, index);
for (int i = 0; i < N_READS; ++i) { #pragma unroll
float t = static_cast<float>(xn[i]) - mean; for (int i = 0; i < N_READS; ++i) {
normalizer += t * t; float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
normalizer += t * t;
}
} }
} }
normalizer = BlockReduceT{block, temp}.Sum(normalizer); normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -100,17 +107,15 @@ __global__ void layer_norm(
// Outputs. // Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
T wn[N_READS]; auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
T bn[N_READS]; auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size); #pragma unroll
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer; float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i]; xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
} }
cub::StoreDirectBlocked(index, out, xn, axis_size); store_vector<N_READS>(out, index, xn, axis_size);
} }
} }
@@ -143,9 +148,11 @@ __global__ void layer_norm_vjp(
float sum = 0; float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {}; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size); #pragma unroll
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{})); for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
} }
sum = BlockReduceF{block, temp.f}.Sum(sum); sum = BlockReduceF{block, temp.f}.Sum(sum);
@@ -155,19 +162,28 @@ __global__ void layer_norm_vjp(
// Normalizer. // Normalizer.
float3 factors = {}; float3 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean); auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
cub::LoadDirectBlocked(index, g, gn, axis_size); auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { if ((index + 1) * N_READS <= axis_size) {
float t = static_cast<float>(xn[i]) - mean; auto xn = load_vector<N_READS>(x, index);
float wi = wn[i]; #pragma unroll
float gi = gn[i]; for (int i = 0; i < N_READS; ++i) {
float wg = wi * gi; float t = static_cast<float>(xn[i]) - mean;
factors = plus_f3(factors, {wg, wg * t, t * t}); float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
} }
} }
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {}); factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -179,12 +195,10 @@ __global__ void layer_norm_vjp(
// Outputs. // Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
T wn[N_READS]; auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
T gn[N_READS]; auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer; float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i]; float wi = wn[i];
@@ -194,9 +208,9 @@ __global__ void layer_norm_vjp(
wn[i] = gi * xi; wn[i] = gi * xi;
} }
} }
cub::StoreDirectBlocked(index, gx, xn, axis_size); store_vector<N_READS>(gx, index, xn, axis_size);
if constexpr (HAS_W) { if constexpr (HAS_W) {
cub::StoreDirectBlocked(index, gw, wn, axis_size); store_vector<N_READS>(gw, index, wn, axis_size);
} }
} }
} }
@@ -257,9 +271,9 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
constexpr uint32_t N_READS = 4; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>; auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -364,10 +378,10 @@ void LayerNormVJP::eval_gpu(
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
dispatch_bool(has_w, [&](auto has_w_constant) { dispatch_bool(has_w, [&](auto has_w_constant) {
constexpr int N_READS = 4; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim( dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm_vjp< auto kernel = cu::layer_norm_vjp<
DataType, DataType,
has_w_constant.value, has_w_constant.value,

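The layer_norm hunks above replace cub::LoadDirectBlocked / cub::StoreDirectBlocked with load_vector / store_vector plus an explicit scalar tail for the last, partially filled iteration. The helpers themselves are not part of this diff; the sketch below is inferred from the call sites and is an assumption about their shape, not MLX's actual kernel_utils code.

#include <cstdint>

// Sketch only: assumed shape of the vectorized helpers used by the kernels above.
template <typename T, int N>
struct AlignedVector {
  T val[N];
  __device__ T& operator[](int i) { return val[i]; }
  __device__ T operator[](int i) const { return val[i]; }
};

// Unchecked load of N contiguous elements starting at element index * N.
template <int N, typename T>
__device__ AlignedVector<T, N> load_vector(const T* ptr, uint32_t index) {
  return reinterpret_cast<const AlignedVector<T, N>*>(ptr)[index];
}

// Bounds-checked load: lanes past `size` are filled with `fallback`.
template <int N, typename T>
__device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t index, uint32_t size, T fallback) {
  AlignedVector<T, N> v;
  for (int i = 0; i < N; ++i) {
    uint32_t j = index * N + i;
    v[i] = (j < size) ? ptr[j] : fallback;
  }
  return v;
}

// Bounds-checked store of N elements starting at element index * N.
template <int N, typename T>
__device__ void
store_vector(T* out, uint32_t index, const AlignedVector<T, N>& v, uint32_t size) {
  for (int i = 0; i < N; ++i) {
    uint32_t j = index * N + i;
    if (j < size) {
      out[j] = v[i];
    }
  }
}

The strided overload used for w and b (load_vector<N_READS>(w, index, axis_size, w_stride, T(0))) presumably indexes ptr[j * stride] instead of ptr[j]. The unchecked form is only legal on the fast path, which is why the variance loop guards it with (index + 1) * N_READS <= axis_size and keeps a scalar else branch.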
View File

@@ -43,20 +43,19 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
AccT maxval = Limits<AccT>::finite_min(); AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0; AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS]; auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked( auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval; prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); #pragma unroll
for (int i = 0; i < N_READS; ++i) {
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
}
// Online normalizer calculation for softmax: // Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax // https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval); normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval); normalizer =
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
} }
} }
@@ -143,9 +142,9 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
constexpr int N_READS = 4; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>; auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,

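Both the logsumexp kernel above and the softmax kernel later in this diff carry a running (maxval, normalizer) pair and rescale the accumulated sum whenever a new maximum appears, which is the online-softmax trick cited in the comment. A self-contained scalar sketch of that update:

#include <cmath>

// Online max / sum-of-exponentials, the update applied per N_READS chunk above.
// Invariant after each iteration: normalizer == sum over seen j of exp(x[j] - maxval).
inline float logsumexp_online(const float* x, int n) {
  float maxval = -INFINITY;
  float normalizer = 0.0f;
  for (int i = 0; i < n; ++i) {
    float prevmax = maxval;
    maxval = std::fmax(maxval, x[i]);
    normalizer = normalizer * std::exp(prevmax - maxval) // rescale the old partial sum
               + std::exp(x[i] - maxval);                // add the new term
  }
  return maxval + std::log(normalizer); // softmax instead scales by 1 / normalizer
}

Because every exponential is taken relative to the current maximum, a single pass suffices and no intermediate value overflows.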
View File

@@ -1,47 +1,11 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core { namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}
bool fast::ScaledDotProductAttention::use_fallback( bool fast::ScaledDotProductAttention::use_fallback(
const array& q, const array& q,
const array& k, const array& k,

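Arange::eval_gpu is removed from this file above; the thrust::transform call it contained fills the output with start_ + i * step through a cu::Arange functor whose definition is not shown in these hunks. A plausible sketch of such a functor, offered as an assumption (the real one lives in mlx/backend/cuda/device/arange.cuh and may differ):

#include <cstdint>

// Hypothetical sketch of the Arange functor passed to thrust::transform above.
template <typename T>
struct Arange {
  T start;
  T step;
  __device__ T operator()(uint32_t i) const {
    return start + step * static_cast<T>(i);
  }
};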
View File

@@ -2,30 +2,17 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
__global__ void __global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) { affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,144 +227,100 @@ __global__ void affine_dequantize(
} }
} // namespace cu } // namespace cu
namespace {
inline array ensure_row_contiguous( void affine_quantize(
const array& x, const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc, cu::CommandEncoder& enc,
const Stream& s) { const Stream& s) {
if (!x.flags().row_contiguous) { // Calculate the number of elements per thread
array x_copy = contiguous_copy_gpu(x, s); int per_thread = group_size_ / WARP_SIZE;
enc.add_temporary(x_copy); size_t size = w.size() / per_thread;
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX; bool large = size > UINT_MAX;
auto grid_shape = w.shape(); auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
if (dequantize_) { enc.set_input_array(w);
grid_shape.back() *= uint8_per_uint32; enc.set_output_array(wq);
} else { enc.set_output_array(scales);
grid_shape.back() /= per_thread; enc.set_output_array(biases);
} dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) { dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) { dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) { auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto kernel = auto [num_blocks, block_dims] =
cu::affine_dequantize<DataType, group_size.value, bits.value>; get_launch_args(size, grid_shape, w.strides(), large);
auto [num_blocks, block_dims] = enc.add_kernel_node(
get_launch_args(kernel, size, grid_shape, w.strides(), large); kernel,
enc.add_kernel_node( num_blocks,
kernel, block_dims,
num_blocks, w.data<T>(),
block_dims, wq.data<uint8_t>(),
w.data<uint8_t>(), scales.data<T>(),
inputs[1].data<DataType>(), biases.data<T>(),
inputs[2].data<DataType>(), w.size());
out.data<DataType>(), });
out.size()); });
} else { });
auto kernel = }
cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] = void affine_dequantize(
get_launch_args(kernel, size, grid_shape, w.strides(), large); const array& wq,
enc.add_kernel_node( const array& scales,
kernel, const array& biases,
num_blocks, array& w,
block_dims, int group_size_,
w.data<DataType>(), int bits_,
out.data<uint8_t>(), cu::CommandEncoder& enc,
outputs[1].data<DataType>(), const Stream& s) {
outputs[2].data<DataType>(), // Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
w.size()); // one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
} constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
}); });
}); });
}); });

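The rewritten launchers size their grids from how many quantized values share a byte group: eight values for 3- and 5-bit, four for 6-bit, and 8 / bits otherwise, the same arithmetic as get_pack_factor / get_bytes_per_pack in the new quantized_utils.cuh shown further down. A small compile-time check of that packing math (host-side sketch):

// Values per pack and bytes per pack for each supported bit width
// (mirrors the switch in affine_dequantize and the helpers in quantized_utils.cuh).
constexpr int pack_factor(int bits) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : 8 / bits);
}
constexpr int bytes_per_pack(int bits) {
  return ((bits & (bits - 1)) == 0) ? 1 : (bits == 5 ? 5 : 3);
}

// 3-bit: 8 * 3 = 24 bits = 3 bytes; 5-bit: 8 * 5 = 40 bits = 5 bytes;
// 6-bit: 4 * 6 = 24 bits = 3 bytes; 2-, 4- and 8-bit packs fit in one byte.
static_assert(pack_factor(3) == 8 && bytes_per_pack(3) == 3, "3-bit packing");
static_assert(pack_factor(5) == 8 && bytes_per_pack(5) == 5, "5-bit packing");
static_assert(pack_factor(6) == 4 && bytes_per_pack(6) == 3, "6-bit packing");
static_assert(pack_factor(4) == 2 && bytes_per_pack(4) == 1, "4-bit packing");
static_assert(pack_factor(8) == 1 && bytes_per_pack(8) == 1, "8-bit packing");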
View File

@@ -0,0 +1,72 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
}
} // namespace
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core

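dispatch_groups and dispatch_bits convert the runtime group_size / bits values into std::integral_constant arguments, so the lambda body can pass .value as a template parameter; that is exactly how affine_quantize and affine_dequantize above pick their kernel instantiations. A usage sketch with a placeholder kernel (namespaces, the command encoder and error handling omitted; toy_kernel is not an MLX kernel):

#include <cstddef>
#include <cstdint>

// Sketch: selecting a kernel instantiation from runtime integers.
template <typename T, int group_size, int bits>
__global__ void toy_kernel(const T* w, uint8_t* out, size_t size) { /* ... */ }

template <typename T>
void launch_toy(const T* w, uint8_t* out, size_t size, int group_size, int bits) {
  dispatch_groups(group_size, [&](auto group_size_c) {
    dispatch_bits(bits, [&](auto bits_c) {
      // group_size_c and bits_c are std::integral_constant<int, N>, so their
      // ::value members are constant expressions usable as template arguments.
      toy_kernel<T, group_size_c.value, bits_c.value><<<1, 256>>>(w, out, size);
    });
  });
}

Unsupported values simply fall out of the switch without invoking the lambda, so callers are expected to validate group_size and bits beforehand.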
View File

@@ -5,8 +5,6 @@
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <cassert> #include <cassert>

View File

@@ -10,8 +10,6 @@
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core { namespace mlx::core {
@@ -57,7 +55,7 @@ __global__ void rms_norm(
const T* w, const T* w,
T* out, T* out,
float eps, float eps,
int32_t axis_size, uint32_t axis_size,
int64_t w_stride) { int64_t w_stride) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
@@ -72,8 +70,8 @@ __global__ void rms_norm(
float normalizer = 0; float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
normalizer += t * t; normalizer += t * t;
@@ -85,15 +83,14 @@ __global__ void rms_norm(
// Outputs. // Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
T wn[N_READS]; auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size); #pragma unroll
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = static_cast<float>(xn[i]) * normalizer; float y = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm); xn[i] = wn[i] * static_cast<T>(y);
} }
cub::StoreDirectBlocked(index, out, xn, axis_size); store_vector<N_READS>(out, index, xn, axis_size);
} }
} }
@@ -125,13 +122,10 @@ __global__ void rms_norm_vjp(
// Normalizer. // Normalizer.
float2 factors = {}; float2 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
cub::LoadDirectBlocked(index, g, gn, axis_size); auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size); auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
float wi = wn[i]; float wi = wn[i];
@@ -148,12 +142,9 @@ __global__ void rms_norm_vjp(
// Outputs. // Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS]; auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
T wn[N_READS]; auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
T gn[N_READS]; auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = xn[i]; float xi = xn[i];
float wi = wn[i]; float wi = wn[i];
@@ -163,9 +154,9 @@ __global__ void rms_norm_vjp(
wn[i] = static_cast<T>(gi * xi * normalizer); wn[i] = static_cast<T>(gi * xi * normalizer);
} }
} }
cub::StoreDirectBlocked(index, gx, xn, axis_size); store_vector<N_READS>(gx, index, xn, axis_size);
if constexpr (HAS_W) { if constexpr (HAS_W) {
cub::StoreDirectBlocked(index, gw, wn, axis_size); store_vector<N_READS>(gw, index, wn, axis_size);
} }
} }
} }
@@ -223,9 +214,9 @@ void RMSNorm::eval_gpu(
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
constexpr uint32_t N_READS = 4; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>; auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -312,11 +303,10 @@ void RMSNormVJP::eval_gpu(
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
dispatch_bool(has_w, [&](auto has_w_constant) { dispatch_bool(has_w, [&](auto has_w_constant) {
constexpr int N_READS = 4; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim( dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4;
auto kernel = cu::rms_norm_vjp< auto kernel = cu::rms_norm_vjp<
DataType, DataType,
has_w_constant.value, has_w_constant.value,

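In these launchers N_READS changes from a fixed 4 to 16 / sizeof(DataType), so each thread moves 16 bytes per loop iteration regardless of dtype; the DataType alias therefore has to be computed before dispatch_block_dim, since the block count now depends on it. The resulting widths, as a compile-time check:

#include <cuda_bf16.h>
#include <cuda_fp16.h>

// 16-byte-per-thread vector width for each dtype (matches the new N_READS above).
template <typename T>
constexpr int n_reads() {
  return static_cast<int>(16 / sizeof(T));
}

static_assert(n_reads<float>() == 4, "float: 4 x 4 B = 16 B per thread");
static_assert(n_reads<__half>() == 8, "half: 8 x 2 B = 16 B per thread");
static_assert(n_reads<__nv_bfloat16>() == 8, "bfloat16: 8 x 2 B = 16 B per thread");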
View File

@@ -11,7 +11,6 @@
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cassert> #include <cassert>
@@ -45,20 +44,21 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
AccT maxval = Limits<AccT>::finite_min(); AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = cast_to<AccT>(0); AccT normalizer = cast_to<AccT>(0);
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS]; auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked( auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval; prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); #pragma unroll
for (int i = 0; i < N_READS; ++i) {
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
}
// Online normalizer calculation for softmax: // Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax // https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval); normalizer = normalizer * softmax_exp(prevmax - maxval);
#pragma unroll
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval); normalizer =
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
} }
} }
@@ -95,12 +95,11 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
// Write output. // Write output.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
T vals[N_READS]; auto vals = load_vector<N_READS>(in, index, axis_size, T(0));
cub::LoadDirectBlocked(index, in, vals, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer; vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
} }
cub::StoreDirectBlocked(index, out, vals, axis_size); store_vector<N_READS>(out, index, vals, axis_size);
} }
} }
@@ -141,9 +140,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
constexpr int N_READS = 4; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>; auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
if (precise) { if (precise) {
kernel = cu::softmax<DataType, float, block_dim(), N_READS>; kernel = cu::softmax<DataType, float, block_dim(), N_READS>;

View File

@@ -32,7 +32,7 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
AlignedVector<T, N_READS> out_vec; AlignedVector<T, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]); out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -125,12 +125,9 @@ void ternary_op_gpu_inplace(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto [num_blocks, block_dims] = get_launch_args(out, large());
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<bool>(), a.data<bool>(),
@@ -144,11 +141,9 @@ void ternary_op_gpu_inplace(
const_param<dims_constant()>(c_strides)); const_param<dims_constant()>(c_strides));
}); });
} else { } else {
auto kernel = cu::ternary_g<Op, DType, IdxT>; auto [num_blocks, block_dims] = get_launch_args(out, large());
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::ternary_g<Op, DType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<bool>(), a.data<bool>(),
@@ -166,18 +161,11 @@ void ternary_op_gpu_inplace(
} else { } else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. constexpr int N_READS = 16 / sizeof(DType);
constexpr int N_READS = 4;
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large(), N_READS);
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks, num_blocks,
block_dims, block_dims,
a.data<bool>(), a.data<bool>(),

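Two things change in this launcher besides the kernel bodies: AlignedVector elements are now indexed with operator[] instead of .val[], and get_launch_args no longer takes the kernel, so grid and block dimensions come from the problem size alone rather than from a per-kernel occupancy query. A heavily hedged sketch of a size-only computation (an illustrative helper, not MLX's get_launch_args; the fixed 1024-thread block is an assumption):

#include <cuda_runtime.h>
#include <cstddef>
#include <utility>

// Illustrative only: 1D grid/block dims from element count and per-thread width.
inline std::pair<dim3, dim3> launch_dims_1d(size_t size, int n_per_thread = 1) {
  constexpr unsigned block_size = 1024; // assumed fixed cap, no occupancy query
  size_t threads = (size + n_per_thread - 1) / n_per_thread;
  unsigned num_blocks =
      static_cast<unsigned>((threads + block_size - 1) / block_size);
  return {dim3(num_blocks), dim3(block_size)};
}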
View File

@@ -30,7 +30,7 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec; AlignedVector<Out, N_READS> out_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(in_vec.val[i]); out_vec[i] = Op{}(in_vec[i]);
} }
store_vector<N_READS>(out, index, out_vec); store_vector<N_READS>(out, index, out_vec);
@@ -129,16 +129,10 @@ void unary_op_gpu_inplace(
using IdxT = std::conditional_t<large(), int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size. // TODO: Choose optimized value based on type size.
constexpr int N_READS = 4; constexpr int N_READS = 4;
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large, N_READS);
out.data_size(),
out.shape(),
out.strides(),
large,
N_READS);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
num_blocks, num_blocks,
block_dims, block_dims,
in.data<InType>(), in.data<InType>(),
@@ -147,10 +141,9 @@ void unary_op_gpu_inplace(
} else { } else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in); auto [shape, strides] = collapse_contiguous_dims(in);
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>; auto [num_blocks, block_dims] = get_launch_args(out, large);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, cu::unary_g<Op, InType, OutType, IdxT>,
num_blocks, num_blocks,
block_dims, block_dims,
in.data<InType>(), in.data<InType>(),

View File

@@ -399,41 +399,7 @@ class Module(dict):
Returns: Returns:
The module instance after updating the submodules. The module instance after updating the submodules.
""" """
_update_modules(self, modules, strict)
def apply(dst, modules):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if self.is_module(current_value) and self.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules)
return self return self
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module: def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
@@ -639,6 +605,36 @@ class Module(dict):
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
def _update_modules(dst, modules, strict):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if Module.is_module(current_value) and Module.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
_update_modules(current_value, new_value, strict)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f'Module does not have sub-module named "{k}".')
elif isinstance(modules, list):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if Module.is_module(current_value) and Module.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
_update_modules(current_value, new_value, strict)
elif strict and new_value != {}:
raise ValueError(f"Received invalid type: {type(new_value).__name__}.")
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn): def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value): if is_leaf_fn(model, value_key, value):
return map_fn(value) return map_fn(value)

View File

@@ -47,7 +47,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5)) self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
def test_matmul_unaligned(self): def test_matmul_unaligned(self):
if not mx.metal.is_available(): if not mx.is_available(mx.gpu):
return return
for dtype in self.dtypes: for dtype in self.dtypes:
@@ -61,8 +61,15 @@ class TestBlas(mlx_tests.MLXTestCase):
shape_b = (dim + p, dim + p) shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype) self.__gemm_test(shape_a, shape_b, np_dtype)
def test_matvec_unaligned(self):
a = mx.random.normal(shape=(4, 128))
b = mx.random.normal(shape=(129,))[1:]
out = a @ b
np_out = np.array(a) @ np.array(b)
self.assertTrue(np.allclose(out, np_out))
def test_matmul_shapes(self): def test_matmul_shapes(self):
if not mx.metal.is_available(): if not mx.is_available(mx.gpu):
return return
shapes = [ shapes = [
@@ -1274,7 +1281,7 @@ class TestBlas(mlx_tests.MLXTestCase):
def test_gemv_gemm_same_precision(self): def test_gemv_gemm_same_precision(self):
mx.random.seed(0) mx.random.seed(0)
N = 256 N = 256
if mx.metal.is_available(): if mx.is_available(mx.gpu):
t = mx.bfloat16 t = mx.bfloat16
a = mx.random.normal([1, N]).astype(t) a = mx.random.normal([1, N]).astype(t)
b = mx.concatenate([a, a], axis=0).astype(t) b = mx.concatenate([a, a], axis=0).astype(t)

View File

@@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase):
del m.weight del m.weight
self.assertFalse(hasattr(m, "weight")) self.assertFalse(hasattr(m, "weight"))
def test_circular_leaks(self):
y = mx.random.uniform(1)
mx.eval(y)
def make_and_update():
model = nn.Linear(1024, 512)
mx.eval(model.parameters())
leaves = {}
model.update_modules(leaves)
mx.synchronize()
pre = mx.get_active_memory()
make_and_update()
mx.synchronize()
post = mx.get_active_memory()
self.assertEqual(pre, post)
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):

View File

@@ -3049,6 +3049,25 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.power(mx.array(0j), float("nan")) out = mx.power(mx.array(0j), float("nan"))
self.assertTrue(mx.isnan(out)) self.assertTrue(mx.isnan(out))
def test_irregular_alignments(self):
# Unaligned unary op
a = mx.ones((64, 1))
b = -a[1:]
self.assertTrue(mx.all(b == -1.0))
# Unaligned binary op
a = mx.ones((64, 1))
b = a[1:]
c = b + b
self.assertTrue(mx.all(c == 2.0))
# Unaligned ternary op
a = mx.ones((64, 1))
b = mx.zeros((63, 1))
c = mx.ones((63, 1)).astype(mx.bool_)
d = mx.where(c, a[1:], b)
self.assertTrue(mx.all(d == 1.0))
class TestBroadcast(mlx_tests.MLXTestCase): class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self): def test_broadcast_shapes(self):

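These tests slice views such as a[1:], so the element-wise inputs start at addresses that are not multiples of 16 bytes. Casting such a pointer to a 16-byte AlignedVector, as the fast vectorized paths in the CUDA kernels do, is only valid when the address is suitably aligned, so the kernels need to detect misaligned views and take an element-wise path instead; the exact mechanism is not visible in this diff. A generic illustration of the constraint being exercised (not MLX code):

#include <cstdint>

// Illustrative check only: a pointer may be read as N-element vectors of T
// only if its address is a multiple of the vector width in bytes.
template <typename T, int N>
__host__ __device__ bool is_vector_aligned(const T* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % (sizeof(T) * N) == 0;
}

// Example: for float* p = base + 1 with a 16-byte-aligned base,
// is_vector_aligned<float, 4>(p) is false, so a float4-style load starting at p
// would be illegal and the kernel must fall back to scalar accesses.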
View File

@@ -44,6 +44,8 @@ def get_version():
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0)) build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
# A CMakeExtension needs a sourcedir instead of a file list. # A CMakeExtension needs a sourcedir instead of a file list.
@@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
"-DMLX_BUILD_EXAMPLES=OFF", "-DMLX_BUILD_EXAMPLES=OFF",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}", f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
] ]
if build_stage == 2 and build_cuda:
# Last arch is always real and virtual for forward-compatibility
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
# Some generators require explicitly passing config when building. # Some generators require explicitly passing config when building.
build_args = ["--config", cfg] build_args = ["--config", cfg]
# Adding CMake arguments set as environment variable # Adding CMake arguments set as environment variable
@@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
# Pass version to C++ # Pass version to C++
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined] cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
if platform.system() == "Darwin": if build_macos:
# Cross-compile support for macOS - respect ARCHFLAGS if set # Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs: if archs:
@@ -113,6 +120,9 @@ class CMakeBuild(build_ext):
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
build_args += [f"-j{os.cpu_count()}"] build_args += [f"-j{os.cpu_count()}"]
# Avoid cache miss when building from temporary dirs.
os.environ["CCACHE_BASEDIR"] = os.path.abspath(self.build_temp)
subprocess.run( subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
) )
@@ -202,9 +212,6 @@ if __name__ == "__main__":
], ],
) )
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
version = get_version() version = get_version()
_setup = partial( _setup = partial(