Compare commits

...

12 Commits

Author SHA1 Message Date
Jagrit Digani
5cf6f10bef Add debug line info 2025-12-02 14:49:11 -08:00
Jagrit Digani
7c1abc50c0 Update make compiled preamble to not preprocess macros 2025-12-02 14:25:00 -08:00
Cheng
2b95d0c270 [CUDA] Use cuDNN attention when T_q != T_kv (#2843)
2025-11-27 09:58:43 +09:00
Chaoran Yu
b054838780 Added clarification to apply_fn parameter of apply_to_modules (#2831)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-26 15:40:56 -08:00
Awni Hannun
dd79d3c465 [CUDA] Faster rms norm for small dimension (#2838) 2025-11-26 15:10:41 -08:00
Cheng
704fd1ae28 [CUDA] Support array mask in SDPA (#2822)
2025-11-26 11:08:58 +09:00
Cheng
c9f4dc851f Merge build-cuda and build-linux actions (#2783)
2025-11-25 20:06:42 +09:00
Cheng
f8bd675655 [CUDA] Output of SDPA should have same layout with inputs (#2826)
2025-11-25 15:22:58 +09:00
Cheng
23a9168d34 [CUDA] Add debug env to save cuda graphs to dot files (#2825) 2025-11-25 15:22:36 +09:00
Awni Hannun
bca205e287 [CUDA] Exit on crash and more helpful errors (#2830) 2025-11-24 19:46:03 -08:00
CCYeh
1d4eacb737 Fix mx.core.linspace type annotation (#2820)
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-24 14:15:08 -08:00
dependabot[bot]
8abd37ad05 Bump actions/checkout from 5 to 6 (#2828)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-24 06:04:46 -08:00
28 changed files with 708 additions and 223 deletions

View File

@@ -1,18 +1,13 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
toolkit:
description: 'The CUDA toolkit'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all

View File

@@ -1,26 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
toolkit:
description: 'The CUDA toolkit'
required: true
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
run: pip install --no-build-isolation -e ".[dev]" -v
- name: Build CPP only
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)

View File

@@ -1,25 +1,41 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
DEBUG: 1
run: pip install --no-build-isolation -e ".[dev]" -v
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Cannot build tests when the built executables cannot run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build CPP only
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)

View File

@@ -51,8 +51,6 @@ runs:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
# it's *not* on the default toolkit path.
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
@@ -60,13 +58,16 @@ runs:
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}

View File

@@ -1,8 +1,8 @@
name: 'Run Linux tests'
inputs:
cpu-only:
description: 'Skip GPU tests'
has-gpu:
description: 'Run GPU tests'
required: false
default: false
@@ -17,7 +17,7 @@ runs:
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.cpu-only == 'true' }}
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
@@ -30,7 +30,7 @@ runs:
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.cpu-only == 'true' }}
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
@@ -40,7 +40,7 @@ runs:
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.cpu-only == 'false' }}
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
@@ -59,7 +59,7 @@ runs:
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.cpu-only == 'false' }}
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu

View File

@@ -17,29 +17,51 @@ concurrency:
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
matrix:
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
fail-fast: false
runs-on: ${{ matrix.runner }}
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
cpu-only: true
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
@@ -49,38 +71,21 @@ jobs:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
if: github.repository == 'ml-explore/mlx'
strategy:
fail-fast: false
matrix:
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-cuda
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
@@ -96,7 +101,7 @@ jobs:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |

View File

@@ -10,7 +10,7 @@ jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:

View File

@@ -16,7 +16,7 @@ jobs:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
@@ -46,14 +46,12 @@ jobs:
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
with:
cpu-only: true
build_mac_release:
if: github.repository == 'ml-explore/mlx'
@@ -62,7 +60,7 @@ jobs:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
@@ -82,7 +80,7 @@ jobs:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'

View File

@@ -25,7 +25,7 @@ jobs:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
@@ -53,7 +53,7 @@ jobs:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
@@ -86,7 +86,7 @@ jobs:
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
@@ -133,14 +133,12 @@ jobs:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:

View File

@@ -123,14 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
execute_process(
COMMAND bash detect_cuda_arch.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMAND __nvcc_device_query
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(UPGRADABLE_ARCHITECTURES "90;100;121")
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
message(
FATAL_ERROR
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
# Use arch-specific compute capability whenever possible.
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
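A minimal Python sketch of the arch-selection logic in the hunk above (assumptions: __nvcc_device_query prints a bare compute capability such as 90, and the trailing "a" requests the arch-specific variant):

def pick_cuda_arch(native: str) -> str:
    # Mirrors the CMake block above: fail without a native arch,
    # otherwise upgrade known archs to their arch-specific "a" variant.
    upgradable = {"90", "100", "121"}
    if not native:
        raise RuntimeError("Cannot detect native CUDA arch; set MLX_CUDA_ARCHITECTURES")
    return native + "a" if native in upgradable else native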

View File

@@ -154,17 +154,21 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
void* data = nullptr;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
err = cudaMallocManaged(&data, size);
} else {
err = cudaMallocAsync(&buf->data, size, stream);
err = cudaMallocAsync(&data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
if (!data) {
return Buffer{nullptr};
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
}

View File

@@ -29,6 +29,10 @@ class CudaHandle {
}
~CudaHandle() {
// Skip if there was an error to avoid throwing in the destructors
if (cudaPeekAtLastError() != cudaSuccess) {
return;
}
reset();
}

View File

@@ -1,13 +0,0 @@
#!/bin/bash
arch=`__nvcc_device_query`
case "$arch" in
"90")
echo "90a" ;;
"100")
echo "100a" ;;
"121")
echo "121a" ;;
*)
echo "native" ;;
esac

View File

@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}();
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
return use_graphs;
}
const char* save_cuda_graphs_dot_file() {
static const char* filename = []() -> const char* {
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
if (env && std::strlen(env) == 0) {
return nullptr;
}
return env;
}();
return filename;
}
} // namespace
Device::Device(int device) : device_(device) {
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Save cuda graph to dot file
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
static int count = 0;
auto path = fmt::format("{}_{}.dot", filename, ++count);
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
}
// Reset state
from_nodes_.clear();
to_nodes_.clear();
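Hedged usage sketch for the new debug variable (assumption: it must be set before the first GPU evaluation; per the counter above, files come out as prefix_1.dot, prefix_2.dot, ...):

import os
os.environ["MLX_SAVE_CUDA_GRAPHS_DOT_FILE"] = "graph"  # assumed: set before first eval

import mlx.core as mx
x = mx.ones((4,), dtype=mx.float32)
mx.eval(x + 1)  # each committed CUDA graph is saved as graph_1.dot, graph_2.dot, ...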

View File

@@ -305,6 +305,7 @@ void Event::wait() {
} else {
event->atomic->wait(value());
}
CHECK_CUDA_ERROR(cudaPeekAtLastError());
}
void Event::wait(Stream s) {

View File

@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
}
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
auto warp = cg::tiled_partition<GROUP_DIM>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
if constexpr (BLOCK_DIM > GROUP_DIM) {
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
} else {
return x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
}
};
template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
__global__ void rms_norm_small(
const T* x,
const T* w,
T* out,
float eps,
uint32_t axis_size,
uint32_t n_rows,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
auto row =
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
if (row >= n_rows) {
return;
}
x += row * axis_size;
out += row * axis_size;
// Normalizer.
float normalizer = 0;
auto index = block.thread_index().x;
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]);
normalizer += t * t;
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float y = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(y);
}
store_vector<N_READS>(out, index, xn, axis_size);
}
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
}
}
template <
typename T,
bool HAS_W,
int BLOCK_DIM,
int REDUCE_DIM,
int N_READS = 4>
__global__ void rms_norm_vjp_small(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int32_t n_rows,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
__shared__ typename BlockReduceF2::TempStorage temp;
auto row =
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
if (row >= n_rows) {
return;
}
x += row * axis_size;
g += row * axis_size;
gx += row * axis_size;
gw += row * axis_size;
// Normalizer.
float2 factors = {};
auto index = block.thread_index().x;
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]);
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f2(factors, {wg * t, t * t});
}
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
// Outputs.
for (int i = 0; i < N_READS; i++) {
float xi = xn[i];
float wi = wn[i];
float gi = gn[i];
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
if constexpr (HAS_W) {
wn[i] = static_cast<T>(gi * xi * normalizer);
}
}
store_vector<N_READS>(gx, index, xn, axis_size);
if constexpr (HAS_W) {
store_vector<N_READS>(gw, index, wn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF2::TempStorage f2;
} temp;
__shared__ typename BlockReduceF2::TempStorage temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
factors = plus_f2(factors, {wg * t, t * t});
}
}
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
template <int n_per_thread, typename F>
void dispatch_group_dim(int axis_size, F&& f) {
if (axis_size <= n_per_thread * 8) {
f(std::integral_constant<int, 8>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 16>());
} else if (axis_size <= n_per_thread * 16) {
f(std::integral_constant<int, 16>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 8>());
} else if (axis_size <= n_per_thread * 32) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 4>());
} else if (axis_size <= n_per_thread * 32 * 2) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 2>(),
std::integral_constant<int, 2>());
} else if (axis_size <= n_per_thread * 32 * 4) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 4>(),
std::integral_constant<int, 1>());
} else if (axis_size <= n_per_thread * 32 * 8) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 8>(),
std::integral_constant<int, 1>());
} else if (axis_size <= n_per_thread * 32 * 16) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 16>(),
std::integral_constant<int, 1>());
} else {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 32>(),
std::integral_constant<int, 1>());
}
}
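A sketch of the selection above: each branch picks (group_dim, n_groups, groups_per_block) so that group_dim * n_groups threads, each loading n_per_thread elements, cover axis_size (assumed intent; names copied from the hunk):

def dispatch_group_dim(axis_size, n_per_thread):
    # (group_dim, n_groups, groups_per_block): smallest block covering axis_size
    table = [(8, 1, 16), (16, 1, 8), (32, 1, 4), (32, 2, 2),
             (32, 4, 1), (32, 8, 1), (32, 16, 1), (32, 32, 1)]
    for group_dim, n_groups, groups_per_block in table:
        if axis_size <= n_per_thread * group_dim * n_groups:
            return group_dim, n_groups, groups_per_block
    return table[-1]  # largest case: 32 * 32 = 1024 threads per row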
// TODO: There is duplicate code with backend/metal/normalization.cpp
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
if (axis_size <= N_READS * 1024) {
dispatch_group_dim<N_READS>(
axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = n_groups() * group_dim();
auto kernel =
cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
1024,
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
eps_,
axis_size,
w_stride);
});
}
});
}
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm_vjp<
DataType,
has_w_constant.value,
block_dim(),
N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
});
if (axis_size <= N_READS * 1024) {
dispatch_group_dim<N_READS>(
axis_size,
[&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = group_dim() * n_groups();
auto kernel = cu::rms_norm_vjp_small<
DataType,
has_w_constant.value,
block_dim,
group_dim(),
N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel =
cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
1024,
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
}
});
});
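For reference, a sketch of the RMS norm the kernels above compute (accumulate in float32, take rsqrt of the mean square plus eps, then scale by w):

import mlx.core as mx

def rms_norm_ref(x, w, eps):
    xf = x.astype(mx.float32)
    normalizer = mx.rsqrt((xf * xf).mean(axis=-1, keepdims=True) + eps)
    return w * (xf * normalizer).astype(x.dtype)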

View File

@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
return x;
}
void malloc_with_same_layout(
cu::CommandEncoder& encoder,
array& o,
const array& q) {
if (q.flags().row_contiguous) {
o.set_data(cu::malloc_async(o.nbytes(), encoder));
return;
}
// fill_order = argsort(q.strides())
Shape fill_order(q.ndim());
std::iota(fill_order.begin(), fill_order.end(), 0);
std::stable_sort(
fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
return s1 < s2;
});
// Generate o_strides with fill_order
Strides o_strides(q.ndim());
int64_t stride = 1;
for (int i : fill_order) {
o_strides[i] = stride;
stride *= o.shape(i);
}
// o is a transposed contiguous array
o.set_data(
cu::malloc_async(o.nbytes(), encoder),
o.size(),
o_strides,
{true, false, false});
}
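A Python sketch of the layout matching above: compute output strides that fill memory in the same axis order as q (a stable argsort of q's strides, zero strides treated as 1):

def strides_like(shape, strides):
    order = sorted(range(len(shape)), key=lambda i: strides[i] if strides[i] > 0 else 1)
    out = [0] * len(shape)
    stride = 1
    for i in order:
        out[i] = stride
        stride *= shape[i]
    return out

# e.g. a transposed view of shape (2, 128, 8, 64) with strides (65536, 64, 8192, 1):
# strides_like((2, 128, 8, 64), (65536, 64, 8192, 1)) == [65536, 64, 8192, 1]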
constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
@@ -75,6 +107,8 @@ struct SDPACacheKey {
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
std::array<int, QKV_NDIM> mask_shape;
std::array<int64_t, QKV_NDIM> mask_strides;
bool output_logsumexp;
};
@@ -84,6 +118,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
@@ -96,20 +131,26 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
{},
{},
output_logsumexp,
};
if (mask_arr) {
cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
}
return cache_key;
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
@@ -118,6 +159,7 @@ enum UIDS {
K,
V,
SCALE,
BIAS,
O,
STATS,
// Backward graph:
@@ -133,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
const array& o,
const array& stats) {
@@ -164,8 +207,19 @@ fe::graph::Graph build_sdpa_graph(
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(output_logsumexp);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
}
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
@@ -192,6 +246,7 @@ fe::graph::Graph build_sdpa_backward_graph(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
const array& o,
const array& d_o,
const array& stats,
@@ -233,7 +288,19 @@ fe::graph::Graph build_sdpa_backward_graph(
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal);
.set_attn_scale(scale);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
}
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
@@ -286,7 +353,6 @@ bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
@@ -299,19 +365,8 @@ bool supports_sdpa_cudnn(
return false;
}
if (has_mask) {
// TODO: Support array masks.
if (!do_causal) {
return false;
}
// FIXME: Causal mask generates wrong results when L_Q != L_K.
if (q.shape(2) != k.shape(2)) {
return false;
}
}
// Only use cuDNN for prefilling and training.
if (q.shape(2) != k.shape(2)) {
// Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
return false;
}
@@ -333,32 +388,33 @@ void sdpa_cudnn(
array& o,
array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder));
malloc_with_same_layout(encoder, o, q);
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, output_logsumexp, o, stats);
handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -369,6 +425,9 @@ void sdpa_cudnn(
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
}
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
@@ -384,6 +443,7 @@ void sdpa_backward_cudnn(
const array& o,
const array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
const array& d_o,
array& d_q,
array& d_k,
@@ -392,10 +452,9 @@ void sdpa_backward_cudnn(
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
malloc_with_same_layout(encoder, d_q, q);
malloc_with_same_layout(encoder, d_k, k);
malloc_with_same_layout(encoder, d_v, v);
encoder.set_input_array(q);
encoder.set_input_array(k);
@@ -406,13 +465,16 @@ void sdpa_backward_cudnn(
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -428,6 +490,9 @@ void sdpa_backward_cudnn(
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
}
execute_graph(encoder, handle, graph, variant_pack);
}
@@ -469,7 +534,11 @@ bool ScaledDotProductAttention::use_fallback(
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
!supports_sdpa_cudnn(q, k, v, do_causal, s);
}
bool ScaledDotProductAttention::supports_bool_mask() {
return false;
}
void ScaledDotProductAttention::eval_gpu(
@@ -487,6 +556,11 @@ void ScaledDotProductAttention::eval_gpu(
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
@@ -495,7 +569,17 @@ void ScaledDotProductAttention::eval_gpu(
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
sdpa_cudnn(
q,
k,
v,
scale_,
out,
stats,
do_causal_,
mask_arr,
output_logsumexp_,
s);
}
}
@@ -515,13 +599,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(
auto& s = stream();
assert(inputs.size() == 6);
assert(inputs.size() >= 6);
int primals_size = inputs.size() - 3;
bool has_arr_mask = primals_size > 3 + has_sinks_;
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[3], s);
array stats = prepare_sdpa_input(inputs[4], s);
array d_o = prepare_sdpa_input(inputs[5], s);
array o = prepare_sdpa_input(inputs[primals_size], s);
array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
assert(outputs.size() == 3);
auto& d_q = outputs[0];
@@ -529,7 +621,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
}
} // namespace fast
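The top-left vs bottom-right causal split above can be pictured with a small sketch (assumed semantics: bottom-right aligns the causal diagonal with the last query position when T_q != T_kv):

import mlx.core as mx

def causal_mask(T_q, T_kv, bottom_right):
    offset = T_kv - T_q if bottom_right else 0
    q_idx = mx.arange(T_q)[:, None]
    k_idx = mx.arange(T_kv)[None, :]
    return k_idx <= q_idx + offset  # True where attention is allowed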

View File

@@ -16,7 +16,77 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CCC="xcrun -sdk macosx metal -x metal"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
declare -a HDRS_LIST=($HDRS)
declare -a HDRS_STACK=()
declare -a HDRS_SORTED=()
length=${#HDRS_LIST[@]}
HDRS_LIST+=(".")
for ((i=0; i<${length}; i+=2));
do
header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
str_this="${HDRS_LIST[$i]}"
str_next="${HDRS_LIST[$i + 2]}"
depth_this=${#str_this}
depth_next=${#str_next}
# If we have a dependency then we stack it
if [ $depth_next -gt $depth_this ]; then
HDRS_STACK=($header ${HDRS_STACK[@]})
# If we are done with this level
else
# We add the header to out list
HDRS_SORTED+=($header)
# Pop the stacked up dependencies
pop_len=$((depth_this - depth_next))
for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
do
HDRS_SORTED+=($popped_header)
done
HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
fi
done
HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
CONTENT=$(
echo "// Copyright © 2025 Apple Inc."
echo ""
echo "// Auto generated source for $INPUT_FILE"
echo ""
for header in "${HDRS_SORTED[@]}"
do
echo "///////////////////////////////////////////////////////////////////////////////"
echo "// Contents from \"${header}\""
echo "///////////////////////////////////////////////////////////////////////////////"
echo ""
echo "#line 1 \"${header}\""
grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
echo ""
done
echo "///////////////////////////////////////////////////////////////////////////////"
)
cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {
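A Python sketch of the header-ordering loop above (assumption: the compiler's -H trace yields (dots, path) pairs where the dot count is include depth; a header must be emitted after everything it includes):

def sort_headers(pairs):
    # pairs: [(".", "a.h"), ("..", "b.h"), ...] from the -H include trace
    out, stack = [], []
    depths = [len(dots) for dots, _ in pairs] + [0]
    for i, (_, header) in enumerate(pairs):
        if depths[i + 1] > depths[i]:
            stack.insert(0, header)  # has nested includes: defer it
        else:
            out.append(header)
            pop = depths[i] - depths[i + 1]
            out.extend(stack[:pop])
            stack = stack[pop:]
    return out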

View File

@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
return !(supports_sdpa_full || supports_sdpa_vector);
}
bool ScaledDotProductAttention::supports_bool_mask() {
return true;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {

View File

@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
return true;
}
bool fast::ScaledDotProductAttention::supports_bool_mask() {
return false;
}
bool fast::ScaledDotProductAttentionVJP::use_fallback(
const array& q,
Stream s) {

View File

@@ -5,3 +5,48 @@
ncclResult_t ncclGetUniqueId(ncclUniqueId*) {
return ncclSuccess;
}
const char* ncclGetErrorString(ncclResult_t result) {
return nullptr;
}
ncclResult_t
ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
return ncclSuccess;
}
ncclResult_t ncclCommDestroy(ncclComm_t comm) {
return ncclSuccess;
}
ncclResult_t ncclAllGather(
const void* sendbuff,
void* recvbuff,
size_t sendcount,
ncclDataType_t datatype,
ncclComm_t comm,
cudaStream_t stream) {
return ncclSuccess;
}
ncclResult_t ncclAllReduce(
const void* sendbuff,
void* recvbuff,
size_t count,
ncclDataType_t datatype,
ncclRedOp_t op,
ncclComm_t comm,
cudaStream_t stream) {
return ncclSuccess;
}
ncclResult_t ncclReduceScatter(
const void* sendbuff,
void* recvbuff,
size_t recvcount,
ncclDataType_t datatype,
ncclRedOp_t op,
ncclComm_t comm,
cudaStream_t stream) {
return ncclSuccess;
}

View File

@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
is_training,
output_logsumexp,
stream)) {
if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
// Convert bool mask to additive mask.
float inf = std::numeric_limits<float>::infinity();
array& mask = inputs[3];
mask = where(
mask,
full_like(mask, 0, final_type, s),
full_like(mask, -inf, final_type, s));
}
Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
auto primitive = std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (int i = 0; i < primals.size(); ++i) {
for (int i = 0; i < /* outputs size */ 3; ++i) {
shapes.push_back(primals[i].shape());
dtypes.push_back(primals[i].dtype());
}
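The bool-to-additive conversion in the first hunk has a direct Python analogue (a sketch; the -inf entries remove masked positions in the softmax):

import mlx.core as mx

mask_bool = mx.array([[True, False], [True, True]])
additive = mx.where(mask_bool, mx.array(0.0), mx.array(-float("inf")))
# adding `additive` to the attention scores reproduces the boolean mask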

View File

@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
bool is_training,
bool output_logsumexp,
Stream s);
static bool supports_bool_mask();
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {

View File

@@ -135,7 +135,11 @@ class Scheduler {
~Scheduler() {
for (auto s : streams_) {
synchronize(s);
try {
synchronize(s);
} catch (const std::runtime_error&) {
// ignore errors if synch fails
}
}
for (auto t : threads_) {
if (t != nullptr) {

View File

@@ -407,7 +407,10 @@ class Module(dict):
instance).
Args:
apply_fn (Callable): The function to apply to the modules.
apply_fn (Callable): The function to apply to the modules which
takes two parameters. The first parameter is the string path of
the module (e.g. ``"model.layers.0.linear"``). The second
parameter is the module object.
Returns:
The module instance after updating submodules.
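Hedged example of an apply_fn with the documented two-parameter signature (the Sequential model here is only for illustration):

import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

def apply_fn(path, module):
    # path is the dotted location of the submodule, e.g. "layers.0"
    print(path, type(module).__name__)

model.apply_to_modules(apply_fn)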

View File

@@ -1445,7 +1445,7 @@ void init_ops(nb::module_& m) {
"dtype"_a.none() = mx::float32,
"stream"_a = nb::none(),
nb::sig(
"def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
"def linspace(start: scalar, stop: scalar, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
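Quick usage matching the corrected signature, where start and stop are scalars:

import mlx.core as mx

xs = mx.linspace(0, 1, num=5)
# array([0, 0.25, 0.5, 0.75, 1], dtype=float32)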

View File

@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
Returns:
list(array): A list of the Jacobian-vector products which
is the same in number, shape, and type of the inputs to ``fun``.
tuple(list(array), list(array)): A tuple with the outputs of
``fun`` in the first position and the Jacobian-vector products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
)pbdoc");
m.def(
"vjp",
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the outputs of ``fun``.
Returns:
list(array): A list of the vector-Jacobian products which
is the same in number, shape, and type of the outputs of ``fun``.
tuple(list(array), list(array)): A tuple with the outputs of
``fun`` in the first position and the vector-Jacobian products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
)pbdoc");
m.def(
"value_and_grad",

View File

@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
def test_sdpa_grad(self):
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
f4 = lambda q, k, v: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
).sum()
# High tolerance due to cuDNN SDPA kernel requiring tf32.
tolerance = {"rtol": 1e-2, "atol": 1e-2}
def test_vjp(slow, fast, primals):
cotan = mx.ones_like(primals[0])
o1, vjp1 = mx.vjp(slow, primals, [cotan])
o2, vjp2 = mx.vjp(fast, primals, [cotan])
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
def test_grad(slow, fast, args):
g1 = mx.grad(slow)(*args)
g2 = mx.grad(fast)(*args)
self.assertTrue(mx.allclose(g1, g2, **tolerance))
sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_mask_fast = lambda q, k, v, mask: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
).sum()
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
for N_q in (8, 32):
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
cotan = mx.ones_like(q)
o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
for mask in (mask_additive, mask_bool):
test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
g1 = mx.grad(f3)(q, k, v)
g2 = mx.grad(f4)(q, k, v)
for mask in (None, "causal"):
sdpa_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
self.assertTrue(mx.allclose(g1, g2, **tolerance))
loss_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
).sum()
test_grad(loss_slow, loss_fast, [q, k, v])
if __name__ == "__main__":