mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
15 Commits
d5f61a93fa
...
jit-nax
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5cf6f10bef | ||
|
|
7c1abc50c0 | ||
|
|
2b95d0c270 | ||
|
|
b054838780 | ||
|
|
dd79d3c465 | ||
|
|
704fd1ae28 | ||
|
|
c9f4dc851f | ||
|
|
f8bd675655 | ||
|
|
23a9168d34 | ||
|
|
bca205e287 | ||
|
|
1d4eacb737 | ||
|
|
8abd37ad05 | ||
|
|
3e05cea9f8 | ||
|
|
5b0f047226 | ||
|
|
618c87af8c |
@@ -1,18 +1,13 @@
|
|||||||
name: 'Build CUDA wheel'
|
name: 'Build CUDA wheel'
|
||||||
description: 'Build CUDA wheel'
|
description: 'Build CUDA wheel'
|
||||||
|
|
||||||
inputs:
|
|
||||||
toolkit:
|
|
||||||
description: 'The CUDA toolkit'
|
|
||||||
required: true
|
|
||||||
|
|
||||||
runs:
|
runs:
|
||||||
using: "composite"
|
using: "composite"
|
||||||
steps:
|
steps:
|
||||||
- name: Build package
|
- name: Build package
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
|
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
|
||||||
run: |
|
run: |
|
||||||
pip install auditwheel build patchelf setuptools
|
pip install auditwheel build patchelf setuptools
|
||||||
python setup.py clean --all
|
python setup.py clean --all
|
||||||
|
|||||||
26
.github/actions/build-cuda/action.yml
vendored
26
.github/actions/build-cuda/action.yml
vendored
@@ -1,26 +0,0 @@
|
|||||||
name: 'Build and Test with CUDA'
|
|
||||||
description: 'Build and test MLX with CUDA'
|
|
||||||
|
|
||||||
inputs:
|
|
||||||
toolkit:
|
|
||||||
description: 'The CUDA toolkit'
|
|
||||||
required: true
|
|
||||||
|
|
||||||
runs:
|
|
||||||
using: "composite"
|
|
||||||
steps:
|
|
||||||
- name: Install Python package
|
|
||||||
shell: bash
|
|
||||||
env:
|
|
||||||
DEBUG: 1
|
|
||||||
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
|
|
||||||
run: pip install --no-build-isolation -e ".[dev]" -v
|
|
||||||
|
|
||||||
- name: Build CPP only
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
cmake . -B build \
|
|
||||||
-DMLX_BUILD_CUDA=ON \
|
|
||||||
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
|
|
||||||
-DCMAKE_BUILD_TYPE=DEBUG
|
|
||||||
cmake --build build -j $(nproc)
|
|
||||||
28
.github/actions/build-linux/action.yml
vendored
28
.github/actions/build-linux/action.yml
vendored
@@ -1,15 +1,32 @@
|
|||||||
name: 'Build and Test on Linux'
|
name: 'Build and Test on Linux'
|
||||||
description: 'Build and test MLX on Linux'
|
|
||||||
|
inputs:
|
||||||
|
toolkit:
|
||||||
|
description: 'The toolkit to build with'
|
||||||
|
required: false
|
||||||
|
default: 'cpu'
|
||||||
|
|
||||||
runs:
|
runs:
|
||||||
using: "composite"
|
using: "composite"
|
||||||
steps:
|
steps:
|
||||||
- name: Install Python package
|
- name: Install Python package
|
||||||
|
id: python_build
|
||||||
shell: sh
|
shell: sh
|
||||||
env:
|
env:
|
||||||
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
|
||||||
DEBUG: 1
|
DEBUG: 1
|
||||||
run: pip install --no-build-isolation -e ".[dev]" -v
|
CMAKE_ARGS: >-
|
||||||
|
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
|
||||||
|
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
|
||||||
|
run: |
|
||||||
|
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
|
||||||
|
# There is no GPU in arm64 runner, use a common arch.
|
||||||
|
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
|
||||||
|
# Can not build tests when the built executables can not run.
|
||||||
|
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
|
||||||
|
fi
|
||||||
|
pip install --no-build-isolation -e ".[dev]" -v
|
||||||
|
# Pass the CMAKE_ARGS to following steps.
|
||||||
|
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Generate package stubs
|
- name: Generate package stubs
|
||||||
shell: sh
|
shell: sh
|
||||||
@@ -20,6 +37,5 @@ runs:
|
|||||||
- name: Build CPP only
|
- name: Build CPP only
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
mkdir -p build && cd build
|
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
|
||||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
cmake --build build -j $(nproc)
|
||||||
make -j $(nproc)
|
|
||||||
|
|||||||
7
.github/actions/setup-linux/action.yml
vendored
7
.github/actions/setup-linux/action.yml
vendored
@@ -51,8 +51,6 @@ runs:
|
|||||||
# Note: the CI machine does not meet CUDA 13's driver requirement.
|
# Note: the CI machine does not meet CUDA 13's driver requirement.
|
||||||
# Compatibility matrix:
|
# Compatibility matrix:
|
||||||
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
|
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
|
||||||
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
|
|
||||||
# it's *not* on the default toolkit path.
|
|
||||||
PACKAGES: |
|
PACKAGES: |
|
||||||
{
|
{
|
||||||
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
|
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
|
||||||
@@ -60,13 +58,16 @@ runs:
|
|||||||
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
|
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
|
||||||
}
|
}
|
||||||
run: |
|
run: |
|
||||||
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
|
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
|
||||||
|
# Jetson specific. SBSA means Arm Server Base System Architecture.
|
||||||
|
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
|
||||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
|
||||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y \
|
sudo apt-get install -y \
|
||||||
libnccl2 libnccl-dev \
|
libnccl2 libnccl-dev \
|
||||||
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
|
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
|
||||||
|
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
- name: CUDA packages and driver report
|
- name: CUDA packages and driver report
|
||||||
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
|
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
|
||||||
|
|||||||
12
.github/actions/test-linux/action.yml
vendored
12
.github/actions/test-linux/action.yml
vendored
@@ -1,8 +1,8 @@
|
|||||||
name: 'Run Linux tests'
|
name: 'Run Linux tests'
|
||||||
|
|
||||||
inputs:
|
inputs:
|
||||||
cpu-only:
|
has-gpu:
|
||||||
description: 'Skip GPU tests'
|
description: 'Run GPU tests'
|
||||||
required: false
|
required: false
|
||||||
default: false
|
default: false
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ runs:
|
|||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
- name: Run distributed tests
|
- name: Run distributed tests
|
||||||
if: ${{ inputs.cpu-only == 'true' }}
|
if: ${{ inputs.has-gpu == 'false' }}
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
echo "::group::Distributed tests"
|
echo "::group::Distributed tests"
|
||||||
@@ -30,7 +30,7 @@ runs:
|
|||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
- name: Run Python tests - CPU
|
- name: Run Python tests - CPU
|
||||||
if: ${{ inputs.cpu-only == 'true' }}
|
if: ${{ inputs.has-gpu == 'false' }}
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
DEVICE: cpu
|
DEVICE: cpu
|
||||||
@@ -40,7 +40,7 @@ runs:
|
|||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
- name: Run Python tests - GPU
|
- name: Run Python tests - GPU
|
||||||
if: ${{ inputs.cpu-only == 'false' }}
|
if: ${{ inputs.has-gpu == 'true' }}
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
DEVICE: gpu
|
DEVICE: gpu
|
||||||
@@ -59,7 +59,7 @@ runs:
|
|||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
- name: Run CPP tests - GPU
|
- name: Run CPP tests - GPU
|
||||||
if: ${{ inputs.cpu-only == 'false' }}
|
if: ${{ inputs.has-gpu == 'true' }}
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
DEVICE: gpu
|
DEVICE: gpu
|
||||||
|
|||||||
@@ -17,29 +17,51 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check_lint:
|
check_lint:
|
||||||
|
name: Check Lint
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: pre-commit/action@v3.0.1
|
- uses: pre-commit/action@v3.0.1
|
||||||
|
|
||||||
linux_build_and_test:
|
linux_build_and_test:
|
||||||
|
name: Linux (cpu, ${{ matrix.arch }})
|
||||||
needs: check_lint
|
needs: check_lint
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
|
||||||
runner:
|
|
||||||
- ubuntu-22.04
|
|
||||||
- ubuntu-22.04-arm
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
runs-on: ${{ matrix.runner }}
|
matrix:
|
||||||
|
arch: ['x86_64', 'aarch64']
|
||||||
|
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
- uses: ./.github/actions/build-linux
|
- uses: ./.github/actions/build-linux
|
||||||
- uses: ./.github/actions/test-linux
|
- uses: ./.github/actions/test-linux
|
||||||
|
|
||||||
|
cuda_build_and_test:
|
||||||
|
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
needs: check_lint
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: ['x86_64', 'aarch64']
|
||||||
|
toolkit: ['cuda-12.6', 'cuda-12.9']
|
||||||
|
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
with:
|
with:
|
||||||
cpu-only: true
|
toolkit: ${{ matrix.toolkit }}
|
||||||
|
- uses: ./.github/actions/build-linux
|
||||||
|
with:
|
||||||
|
toolkit: ${{ matrix.toolkit }}
|
||||||
|
- uses: ./.github/actions/test-linux
|
||||||
|
if: matrix.arch == 'x86_64'
|
||||||
|
with:
|
||||||
|
has-gpu: true
|
||||||
|
|
||||||
mac_build_and_test:
|
mac_build_and_test:
|
||||||
|
name: macOS (${{ matrix.macos-target }})
|
||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -49,38 +71,21 @@ jobs:
|
|||||||
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
|
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
|
||||||
needs: check_lint
|
needs: check_lint
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-macos
|
- uses: ./.github/actions/setup-macos
|
||||||
- uses: ./.github/actions/build-macos
|
- uses: ./.github/actions/build-macos
|
||||||
|
|
||||||
cuda_build_and_test:
|
|
||||||
if: github.repository == 'ml-explore/mlx'
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
toolkit: ['cuda-12.6', 'cuda-12.9']
|
|
||||||
runs-on: gpu-t4-4-core
|
|
||||||
needs: check_lint
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v5
|
|
||||||
- uses: ./.github/actions/setup-linux
|
|
||||||
with:
|
|
||||||
toolkit: ${{ matrix.toolkit }}
|
|
||||||
- uses: ./.github/actions/build-cuda
|
|
||||||
with:
|
|
||||||
toolkit: ${{ matrix.toolkit }}
|
|
||||||
- uses: ./.github/actions/test-linux
|
|
||||||
|
|
||||||
build_documentation:
|
build_documentation:
|
||||||
|
name: Build Documentation
|
||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
needs: check_lint
|
needs: check_lint
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/build-docs
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
linux_fedora_build_cpp:
|
linux_fedora_build_cpp:
|
||||||
name: Linux Fedora CPP Build (${{ matrix.arch }})
|
name: Linux Fedora (${{ matrix.arch }})
|
||||||
needs: check_lint
|
needs: check_lint
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@@ -96,7 +101,7 @@ jobs:
|
|||||||
image: fedora:42
|
image: fedora:42
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: CPP Build Test - No Release
|
- name: CPP Build Test - No Release
|
||||||
run: |
|
run: |
|
||||||
2
.github/workflows/documentation.yml
vendored
2
.github/workflows/documentation.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
|||||||
build:
|
build:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/build-docs
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
|
|||||||
10
.github/workflows/nightly.yml
vendored
10
.github/workflows/nightly.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
|||||||
python_version: ["3.10", "3.14"]
|
python_version: ["3.10", "3.14"]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
- uses: ./.github/actions/build-linux-release
|
- uses: ./.github/actions/build-linux-release
|
||||||
with:
|
with:
|
||||||
@@ -46,14 +46,12 @@ jobs:
|
|||||||
- ubuntu-22.04-arm
|
- ubuntu-22.04-arm
|
||||||
runs-on: ${{ matrix.runner }}
|
runs-on: ${{ matrix.runner }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
- uses: ./.github/actions/build-linux
|
- uses: ./.github/actions/build-linux
|
||||||
- uses: ./.github/actions/test-linux
|
- uses: ./.github/actions/test-linux
|
||||||
with:
|
|
||||||
cpu-only: true
|
|
||||||
|
|
||||||
build_mac_release:
|
build_mac_release:
|
||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
@@ -62,7 +60,7 @@ jobs:
|
|||||||
python-version: ["3.10", "3.13"]
|
python-version: ["3.10", "3.13"]
|
||||||
runs-on: [self-hosted, macos]
|
runs-on: [self-hosted, macos]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-macos
|
- uses: ./.github/actions/setup-macos
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
@@ -82,7 +80,7 @@ jobs:
|
|||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
runs-on: ubuntu-22-large
|
runs-on: ubuntu-22-large
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
with:
|
with:
|
||||||
toolkit: 'cuda-12.9'
|
toolkit: 'cuda-12.9'
|
||||||
|
|||||||
10
.github/workflows/release.yml
vendored
10
.github/workflows/release.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/build-docs
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
deploy_documentation:
|
deploy_documentation:
|
||||||
@@ -53,7 +53,7 @@ jobs:
|
|||||||
PYPI_RELEASE: 1
|
PYPI_RELEASE: 1
|
||||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
@@ -86,7 +86,7 @@ jobs:
|
|||||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-macos
|
- uses: ./.github/actions/setup-macos
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
@@ -133,14 +133,12 @@ jobs:
|
|||||||
PYPI_RELEASE: 1
|
PYPI_RELEASE: 1
|
||||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: ./.github/actions/setup-linux
|
- uses: ./.github/actions/setup-linux
|
||||||
with:
|
with:
|
||||||
toolkit: 'cuda-12.9'
|
toolkit: 'cuda-12.9'
|
||||||
- name: Build Python package
|
- name: Build Python package
|
||||||
uses: ./.github/actions/build-cuda-release
|
uses: ./.github/actions/build-cuda-release
|
||||||
with:
|
|
||||||
toolkit: 'cuda-12.9'
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -12,6 +12,167 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
complex64_t to_complex(T r, T i) {
|
||||||
|
return {static_cast<float>(r), static_cast<float>(i)};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, class Enable = void>
|
||||||
|
struct EigWork {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct EigWork<
|
||||||
|
T,
|
||||||
|
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||||
|
using O = complex64_t;
|
||||||
|
|
||||||
|
char jobl;
|
||||||
|
char jobr;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
|
||||||
|
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
|
||||||
|
T work;
|
||||||
|
int n_vecs_l = compute_eigenvectors ? N_ : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work);
|
||||||
|
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
|
||||||
|
if (compute_eigenvectors) {
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
|
||||||
|
}
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, O* values, O* vectors) {
|
||||||
|
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
|
||||||
|
T* vec_tmp = nullptr;
|
||||||
|
if (vectors) {
|
||||||
|
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
|
||||||
|
}
|
||||||
|
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
|
||||||
|
|
||||||
|
int n_vecs_l = vectors ? N : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
a,
|
||||||
|
&N,
|
||||||
|
eig_tmp,
|
||||||
|
eig_tmp + N,
|
||||||
|
vectors ? vec_tmp : nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
work,
|
||||||
|
&lwork,
|
||||||
|
&info);
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vectors) {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
if (values[i].imag() != 0) {
|
||||||
|
for (int j = 0; j < N; ++j) {
|
||||||
|
vectors[i * N + j] =
|
||||||
|
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
|
||||||
|
vectors[(i + 1) * N + j] =
|
||||||
|
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
} else {
|
||||||
|
for (int j = 0; j < N; ++j) {
|
||||||
|
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct EigWork<std::complex<float>> {
|
||||||
|
using T = std::complex<float>;
|
||||||
|
using R = float;
|
||||||
|
using O = T;
|
||||||
|
|
||||||
|
char jobl;
|
||||||
|
char jobr;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int lrwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
|
||||||
|
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
|
||||||
|
T work;
|
||||||
|
R rwork;
|
||||||
|
int n_vecs_l = compute_eigenvectors ? N_ : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&rwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work.real());
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, T* values, T* vectors) {
|
||||||
|
int n_vecs_l = vectors ? N : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
a,
|
||||||
|
&N,
|
||||||
|
values,
|
||||||
|
vectors,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
&info);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void eig_impl(
|
void eig_impl(
|
||||||
array& a,
|
array& a,
|
||||||
@@ -19,101 +180,39 @@ void eig_impl(
|
|||||||
array& values,
|
array& values,
|
||||||
bool compute_eigenvectors,
|
bool compute_eigenvectors,
|
||||||
Stream stream) {
|
Stream stream) {
|
||||||
using OT = std::complex<T>;
|
|
||||||
auto a_ptr = a.data<T>();
|
auto a_ptr = a.data<T>();
|
||||||
auto eig_ptr = values.data<OT>();
|
auto val_ptr = values.data<complex64_t>();
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
encoder.set_output_array(values);
|
encoder.set_output_array(values);
|
||||||
OT* vec_ptr = nullptr;
|
complex64_t* vec_ptr = nullptr;
|
||||||
if (compute_eigenvectors) {
|
if (compute_eigenvectors) {
|
||||||
encoder.set_output_array(vectors);
|
encoder.set_output_array(vectors);
|
||||||
vec_ptr = vectors.data<OT>();
|
vec_ptr = vectors.data<complex64_t>();
|
||||||
}
|
}
|
||||||
encoder.dispatch([a_ptr,
|
encoder.dispatch([a_ptr,
|
||||||
|
val_ptr,
|
||||||
vec_ptr,
|
vec_ptr,
|
||||||
eig_ptr,
|
|
||||||
compute_eigenvectors,
|
compute_eigenvectors,
|
||||||
N = vectors.shape(-1),
|
N = vectors.shape(-1),
|
||||||
size = vectors.size()]() mutable {
|
size = vectors.size()]() mutable {
|
||||||
// Work query
|
|
||||||
char jobr = 'N';
|
char jobr = 'N';
|
||||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||||
int n_vecs_r = 1;
|
|
||||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
|
||||||
int lwork = -1;
|
|
||||||
int info;
|
|
||||||
{
|
|
||||||
T work;
|
|
||||||
geev<T>(
|
|
||||||
&jobl,
|
|
||||||
&jobr,
|
|
||||||
&N,
|
|
||||||
nullptr,
|
|
||||||
&N,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
&n_vecs_l,
|
|
||||||
nullptr,
|
|
||||||
&n_vecs_r,
|
|
||||||
&work,
|
|
||||||
&lwork,
|
|
||||||
&info);
|
|
||||||
lwork = static_cast<int>(work);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
|
||||||
auto vec_tmp_data =
|
|
||||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
|
||||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
|
||||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
|
||||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
|
||||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||||
geev<T>(
|
work.run(a_ptr, val_ptr, vec_ptr);
|
||||||
&jobl,
|
a_ptr += N * N;
|
||||||
&jobr,
|
val_ptr += N;
|
||||||
&N,
|
|
||||||
a_ptr,
|
|
||||||
&N,
|
|
||||||
eig_tmp,
|
|
||||||
eig_tmp + N,
|
|
||||||
vec_tmp,
|
|
||||||
&n_vecs_l,
|
|
||||||
nullptr,
|
|
||||||
&n_vecs_r,
|
|
||||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
|
||||||
&lwork,
|
|
||||||
&info);
|
|
||||||
for (int i = 0; i < N; ++i) {
|
|
||||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
|
||||||
}
|
|
||||||
if (vec_ptr) {
|
if (vec_ptr) {
|
||||||
for (int i = 0; i < N; ++i) {
|
|
||||||
if (eig_ptr[i].imag() != 0) {
|
|
||||||
// This vector and the next are a pair
|
|
||||||
for (int j = 0; j < N; ++j) {
|
|
||||||
vec_ptr[i * N + j] = {
|
|
||||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
|
||||||
vec_ptr[(i + 1) * N + j] = {
|
|
||||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
|
||||||
}
|
|
||||||
i += 1;
|
|
||||||
} else {
|
|
||||||
for (int j = 0; j < N; ++j) {
|
|
||||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
vec_ptr += N * N;
|
vec_ptr += N * N;
|
||||||
}
|
}
|
||||||
a_ptr += N * N;
|
if (work.info != 0) {
|
||||||
eig_ptr += N;
|
|
||||||
if (info != 0) {
|
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||||
<< info;
|
<< work.info;
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
|
|||||||
case float32:
|
case float32:
|
||||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
break;
|
break;
|
||||||
|
case float64:
|
||||||
|
eig_impl<double>(
|
||||||
|
a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
eig_impl<std::complex<float>>(
|
||||||
|
a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
throw std::runtime_error(
|
||||||
|
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,9 +45,7 @@
|
|||||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||||
INSTANTIATE_LAPACK_REAL(syevd)
|
INSTANTIATE_LAPACK_REAL(syevd)
|
||||||
INSTANTIATE_LAPACK_REAL(geev)
|
|
||||||
INSTANTIATE_LAPACK_REAL(potrf)
|
INSTANTIATE_LAPACK_REAL(potrf)
|
||||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
|
||||||
INSTANTIATE_LAPACK_REAL(getrf)
|
INSTANTIATE_LAPACK_REAL(getrf)
|
||||||
INSTANTIATE_LAPACK_REAL(getri)
|
INSTANTIATE_LAPACK_REAL(getri)
|
||||||
INSTANTIATE_LAPACK_REAL(trtri)
|
INSTANTIATE_LAPACK_REAL(trtri)
|
||||||
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
|
|||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||||
|
|
||||||
|
#define INSTANTIATE_LAPACK_ALL(FUNC) \
|
||||||
|
template <typename T, typename... Args> \
|
||||||
|
void FUNC(Args... args) { \
|
||||||
|
if constexpr (std::is_same_v<T, float>) { \
|
||||||
|
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, double>) { \
|
||||||
|
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||||
|
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||||
|
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_LAPACK_ALL(geev)
|
||||||
|
INSTANTIATE_LAPACK_ALL(gesdd)
|
||||||
|
|||||||
@@ -8,6 +8,183 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T, class Enable = void>
|
||||||
|
struct SVDWork {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SVDWork<
|
||||||
|
T,
|
||||||
|
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||||
|
using R = T;
|
||||||
|
|
||||||
|
int N;
|
||||||
|
int M;
|
||||||
|
int K;
|
||||||
|
int lda;
|
||||||
|
int ldu;
|
||||||
|
int ldvt;
|
||||||
|
char jobz;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
int lwork;
|
||||||
|
|
||||||
|
SVDWork(int N, int M, int K, char jobz)
|
||||||
|
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
|
||||||
|
T workspace_dimension = 0;
|
||||||
|
|
||||||
|
// Will contain the indices of eigenvectors that failed to converge (not
|
||||||
|
// used here but required by lapack).
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
|
||||||
|
|
||||||
|
int lwork_query = -1;
|
||||||
|
int info;
|
||||||
|
|
||||||
|
// Compute workspace size.
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ nullptr,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ nullptr,
|
||||||
|
/* u = */ nullptr,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
/* vt = */ nullptr,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ &workspace_dimension,
|
||||||
|
/* lwork = */ &lwork_query,
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
lwork = workspace_dimension;
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, R* s, T* u, T* vt) {
|
||||||
|
int info;
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ a,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ s,
|
||||||
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
|
/* u = */ u,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||||
|
/* vt = */ vt,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
/* lwork = */ &lwork,
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct SVDWork<std::complex<float>> {
|
||||||
|
using T = std::complex<float>;
|
||||||
|
using R = float;
|
||||||
|
|
||||||
|
int N;
|
||||||
|
int M;
|
||||||
|
int K;
|
||||||
|
int lda;
|
||||||
|
int ldu;
|
||||||
|
int ldvt;
|
||||||
|
char jobz;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
int lwork;
|
||||||
|
|
||||||
|
SVDWork(int N, int M, int K, char jobz)
|
||||||
|
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
|
||||||
|
T workspace_dimension = 0;
|
||||||
|
|
||||||
|
// Will contain the indices of eigenvectors that failed to converge (not
|
||||||
|
// used here but required by lapack).
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
|
||||||
|
|
||||||
|
const int lrwork =
|
||||||
|
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
|
||||||
|
|
||||||
|
int lwork_query = -1;
|
||||||
|
int work_query = -1;
|
||||||
|
int info;
|
||||||
|
|
||||||
|
// Compute workspace size.
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ nullptr,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ nullptr,
|
||||||
|
/* u = */ nullptr,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
/* vt = */ nullptr,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ &workspace_dimension,
|
||||||
|
/* lwork = */ &lwork_query,
|
||||||
|
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
lwork = workspace_dimension.real();
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, R* s, T* u, T* vt) {
|
||||||
|
int info;
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ a,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ s,
|
||||||
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
|
/* u = */ u,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||||
|
/* vt = */ vt,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
|
||||||
|
/* lwork = */ &lwork,
|
||||||
|
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void svd_impl(
|
void svd_impl(
|
||||||
const array& a,
|
const array& a,
|
||||||
@@ -27,6 +204,8 @@ void svd_impl(
|
|||||||
const int N = a.shape(-1);
|
const int N = a.shape(-1);
|
||||||
const int K = std::min(M, N);
|
const int K = std::min(M, N);
|
||||||
|
|
||||||
|
using R = typename SVDWork<T>::R;
|
||||||
|
|
||||||
size_t num_matrices = a.size() / (M * N);
|
size_t num_matrices = a.size() / (M * N);
|
||||||
|
|
||||||
// lapack clobbers the input, so we have to make a copy.
|
// lapack clobbers the input, so we have to make a copy.
|
||||||
@@ -42,7 +221,7 @@ void svd_impl(
|
|||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
auto in_ptr = in.data<T>();
|
auto in_ptr = in.data<T>();
|
||||||
T* u_ptr;
|
T* u_ptr;
|
||||||
T* s_ptr;
|
R* s_ptr;
|
||||||
T* vt_ptr;
|
T* vt_ptr;
|
||||||
|
|
||||||
if (compute_uv) {
|
if (compute_uv) {
|
||||||
@@ -58,7 +237,7 @@ void svd_impl(
|
|||||||
encoder.set_output_array(s);
|
encoder.set_output_array(s);
|
||||||
encoder.set_output_array(vt);
|
encoder.set_output_array(vt);
|
||||||
|
|
||||||
s_ptr = s.data<T>();
|
s_ptr = s.data<R>();
|
||||||
u_ptr = u.data<T>();
|
u_ptr = u.data<T>();
|
||||||
vt_ptr = vt.data<T>();
|
vt_ptr = vt.data<T>();
|
||||||
} else {
|
} else {
|
||||||
@@ -68,96 +247,26 @@ void svd_impl(
|
|||||||
|
|
||||||
encoder.set_output_array(s);
|
encoder.set_output_array(s);
|
||||||
|
|
||||||
s_ptr = s.data<T>();
|
s_ptr = s.data<R>();
|
||||||
u_ptr = nullptr;
|
u_ptr = nullptr;
|
||||||
vt_ptr = nullptr;
|
vt_ptr = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
|
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
|
||||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
auto jobz = (u_ptr) ? 'A' : 'N';
|
||||||
const int lda = N;
|
SVDWork<T> svd_work(N, M, K, jobz);
|
||||||
// U of shape M x M. (N x N in lapack).
|
|
||||||
const int ldu = N;
|
|
||||||
// Vᵀ of shape N x N. (M x M in lapack).
|
|
||||||
const int ldvt = M;
|
|
||||||
|
|
||||||
auto jobz = (u_ptr) ? "A" : "N";
|
|
||||||
|
|
||||||
T workspace_dimension = 0;
|
|
||||||
|
|
||||||
// Will contain the indices of eigenvectors that failed to converge (not
|
|
||||||
// used here but required by lapack).
|
|
||||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
|
||||||
|
|
||||||
static const int lwork_query = -1;
|
|
||||||
|
|
||||||
int info;
|
|
||||||
|
|
||||||
// Compute workspace size.
|
|
||||||
gesdd<T>(
|
|
||||||
/* jobz = */ jobz,
|
|
||||||
// M and N are swapped since lapack expects column-major.
|
|
||||||
/* m = */ &N,
|
|
||||||
/* n = */ &M,
|
|
||||||
/* a = */ nullptr,
|
|
||||||
/* lda = */ &lda,
|
|
||||||
/* s = */ nullptr,
|
|
||||||
/* u = */ nullptr,
|
|
||||||
/* ldu = */ &ldu,
|
|
||||||
/* vt = */ nullptr,
|
|
||||||
/* ldvt = */ &ldvt,
|
|
||||||
/* work = */ &workspace_dimension,
|
|
||||||
/* lwork = */ &lwork_query,
|
|
||||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
|
||||||
/* info = */ &info);
|
|
||||||
|
|
||||||
if (info != 0) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
|
||||||
throw std::runtime_error(ss.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
const int lwork = workspace_dimension;
|
|
||||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
|
||||||
|
|
||||||
// Loop over matrices.
|
// Loop over matrices.
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
gesdd<T>(
|
svd_work.run(
|
||||||
/* jobz = */ jobz,
|
in_ptr + M * N * i,
|
||||||
// M and N are swapped since lapack expects column-major.
|
s_ptr + K * i,
|
||||||
/* m = */ &N,
|
vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||||
/* n = */ &M,
|
u_ptr ? u_ptr + M * M * i : nullptr);
|
||||||
/* a = */ in_ptr + M * N * i,
|
|
||||||
/* lda = */ &lda,
|
|
||||||
/* s = */ s_ptr + K * i,
|
|
||||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
|
||||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
|
||||||
/* ldu = */ &ldu,
|
|
||||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
|
||||||
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
|
|
||||||
/* ldvt = */ &ldvt,
|
|
||||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
|
||||||
/* lwork = */ &lwork,
|
|
||||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
|
||||||
/* info = */ &info);
|
|
||||||
|
|
||||||
if (info != 0) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
|
||||||
throw std::runtime_error(ss.str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void compute_svd(
|
|
||||||
const array& a,
|
|
||||||
bool compute_uv,
|
|
||||||
std::vector<array>& outputs,
|
|
||||||
Stream stream) {}
|
|
||||||
|
|
||||||
void SVD::eval_cpu(
|
void SVD::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
|
|||||||
case float64:
|
case float64:
|
||||||
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
|
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
|
||||||
break;
|
break;
|
||||||
|
case complex64:
|
||||||
|
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[SVD::eval_cpu] only supports float32 or float64.");
|
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -123,14 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
|||||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
# Use native CUDA arch by default.
|
||||||
# managed memory.
|
|
||||||
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND bash detect_cuda_arch.sh
|
COMMAND __nvcc_device_query
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
|
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
set(UPGRADABLE_ARCHITECTURES "90;100;121")
|
||||||
|
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
|
||||||
|
message(
|
||||||
|
FATAL_ERROR
|
||||||
|
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
|
||||||
|
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
|
||||||
|
# Use arch-specific compute capability whenever possible.
|
||||||
|
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||||
|
|||||||
@@ -154,17 +154,21 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
|||||||
}
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
buf = new CudaBuffer{nullptr, size, device};
|
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
|
void* data = nullptr;
|
||||||
if (device == -1) {
|
if (device == -1) {
|
||||||
err = cudaMallocManaged(&buf->data, size);
|
err = cudaMallocManaged(&data, size);
|
||||||
} else {
|
} else {
|
||||||
err = cudaMallocAsync(&buf->data, size, stream);
|
err = cudaMallocAsync(&data, size, stream);
|
||||||
}
|
}
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||||
}
|
}
|
||||||
|
if (!data) {
|
||||||
|
return Buffer{nullptr};
|
||||||
|
}
|
||||||
|
buf = new CudaBuffer{data, size, device};
|
||||||
}
|
}
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ class CudaHandle {
|
|||||||
}
|
}
|
||||||
|
|
||||||
~CudaHandle() {
|
~CudaHandle() {
|
||||||
|
// Skip if there was an error to avoid throwing in the destructors
|
||||||
|
if (cudaPeekAtLastError() != cudaSuccess) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
reset();
|
reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
arch=`__nvcc_device_query`
|
|
||||||
case "$arch" in
|
|
||||||
"90")
|
|
||||||
echo "90a" ;;
|
|
||||||
"100")
|
|
||||||
echo "100a" ;;
|
|
||||||
"121")
|
|
||||||
echo "121a" ;;
|
|
||||||
*)
|
|
||||||
echo "native" ;;
|
|
||||||
esac
|
|
||||||
@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool use_cuda_graphs() {
|
bool use_cuda_graphs() {
|
||||||
static bool use_graphs = []() {
|
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
|
||||||
}();
|
|
||||||
return use_graphs;
|
return use_graphs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* save_cuda_graphs_dot_file() {
|
||||||
|
static const char* filename = []() -> const char* {
|
||||||
|
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
|
||||||
|
if (env && std::strlen(env) == 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return env;
|
||||||
|
}();
|
||||||
|
return filename;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
@@ -115,18 +124,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Use an empty graph node for synchronization
|
// Use an empty graph node for synchronization
|
||||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
|
||||||
enc.empty_node_count_++;
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
|
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
|
||||||
|
|
||||||
// Insert the concurrent -> empty node dependencies
|
// Insert the concurrent -> empty node dependencies
|
||||||
for (auto& from : enc.concurrent_nodes_) {
|
for (auto& from : enc.concurrent_nodes_) {
|
||||||
enc.from_nodes_.push_back(from.node);
|
enc.from_nodes_.push_back(from.node);
|
||||||
enc.to_nodes_.push_back(empty.node);
|
enc.to_nodes_.push_back(empty.node);
|
||||||
enc.graph_key_ += from.id;
|
enc.graph_deps_key_ += from.id;
|
||||||
enc.graph_key_ += from.node_type;
|
enc.graph_deps_key_ += "-";
|
||||||
enc.graph_key_ += empty.id;
|
enc.graph_deps_key_ += empty.id;
|
||||||
enc.graph_key_ += empty.node_type;
|
enc.graph_deps_key_ += "-";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert the input -> concurrent node dependencies without updating output
|
// Insert the input -> concurrent node dependencies without updating output
|
||||||
@@ -141,9 +149,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
||||||
if (node.node_type == 'G') {
|
|
||||||
graph_node_count_++;
|
|
||||||
}
|
|
||||||
node.id = std::to_string(node_count_++);
|
node.id = std::to_string(node_count_++);
|
||||||
if (in_concurrent_) {
|
if (in_concurrent_) {
|
||||||
concurrent_nodes_.push_back(std::move(node));
|
concurrent_nodes_.push_back(std::move(node));
|
||||||
@@ -155,6 +160,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||||
|
for (auto& node : nodes) {
|
||||||
|
graph_nodes_key_ += node.node_type;
|
||||||
|
graph_nodes_key_ += "-";
|
||||||
|
}
|
||||||
std::vector<GraphNode> deps;
|
std::vector<GraphNode> deps;
|
||||||
{
|
{
|
||||||
// Dependencies must be added in the same order to produce a consistent
|
// Dependencies must be added in the same order to produce a consistent
|
||||||
@@ -182,10 +191,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
|||||||
for (auto& to : nodes) {
|
for (auto& to : nodes) {
|
||||||
from_nodes_.push_back(from.node);
|
from_nodes_.push_back(from.node);
|
||||||
to_nodes_.push_back(to.node);
|
to_nodes_.push_back(to.node);
|
||||||
graph_key_ += from.id;
|
graph_deps_key_ += from.id;
|
||||||
graph_key_ += from.node_type;
|
graph_deps_key_ += "-";
|
||||||
graph_key_ += to.id;
|
graph_deps_key_ += to.id;
|
||||||
graph_key_ += to.node_type;
|
graph_deps_key_ += "-";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -309,13 +318,46 @@ void CommandEncoder::add_kernel_node(
|
|||||||
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
insert_graph_dependencies(GraphNode{node, "K"});
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||||
CUgraphNode node;
|
CUgraphNode node;
|
||||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
insert_graph_dependencies(GraphNode{node, "K"});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||||
|
// CUDA graphs do not get updated correctly if a kernel node getting updated
|
||||||
|
// has a different cluster shape than the node it's being updated with.
|
||||||
|
size_t num_nodes = 0;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
|
||||||
|
if (num_nodes == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
|
||||||
|
for (const auto& node : nodes) {
|
||||||
|
cudaGraphNodeType type;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||||
|
if (type != cudaGraphNodeTypeKernel) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
cudaLaunchAttributeValue cluster_dim;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||||
|
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||||
|
// Only dim.x can be greater than 1
|
||||||
|
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Only one child node allowed when subgraph uses clusters
|
||||||
|
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
@@ -328,8 +370,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
|
int cluster_dim_x = 0;
|
||||||
|
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
insert_graph_dependencies(
|
||||||
|
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CommandEncoder::needs_commit() {
|
bool CommandEncoder::needs_commit() {
|
||||||
@@ -354,44 +399,53 @@ void CommandEncoder::commit() {
|
|||||||
from_nodes_.size()));
|
from_nodes_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_key_ += ".";
|
|
||||||
graph_key_ += std::to_string(node_count_);
|
|
||||||
graph_key_ += ".";
|
|
||||||
graph_key_ += std::to_string(graph_node_count_);
|
|
||||||
graph_key_ += ".";
|
|
||||||
graph_key_ += std::to_string(empty_node_count_);
|
|
||||||
|
|
||||||
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
|
|
||||||
|
|
||||||
if (graph_exec != nullptr) {
|
|
||||||
cudaGraphExecUpdateResult update_result;
|
|
||||||
#if CUDART_VERSION >= 12000
|
|
||||||
cudaGraphExecUpdateResultInfo info;
|
|
||||||
cudaGraphExecUpdate(graph_exec, graph_, &info);
|
|
||||||
update_result = info.result;
|
|
||||||
#else
|
|
||||||
cudaGraphNode_t error_node;
|
|
||||||
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
|
|
||||||
#endif // CUDART_VERSION >= 12000
|
|
||||||
if (update_result != cudaGraphExecUpdateSuccess) {
|
|
||||||
cudaGetLastError(); // reset error
|
|
||||||
graph_exec.reset();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (graph_exec == nullptr) {
|
|
||||||
graph_exec.instantiate(graph_);
|
|
||||||
}
|
|
||||||
device_.make_current();
|
device_.make_current();
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
|
||||||
|
if (!is_graph_updatable_) {
|
||||||
|
CudaGraphExec graph_exec;
|
||||||
|
graph_exec.instantiate(graph_);
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
} else {
|
||||||
|
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
|
||||||
|
auto& graph_exec = graph_cache_[graph_key];
|
||||||
|
|
||||||
|
if (graph_exec != nullptr) {
|
||||||
|
cudaGraphExecUpdateResult update_result;
|
||||||
|
#if CUDART_VERSION >= 12000
|
||||||
|
cudaGraphExecUpdateResultInfo info;
|
||||||
|
cudaGraphExecUpdate(graph_exec, graph_, &info);
|
||||||
|
update_result = info.result;
|
||||||
|
#else
|
||||||
|
cudaGraphNode_t error_node;
|
||||||
|
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
|
||||||
|
#endif // CUDART_VERSION >= 12000
|
||||||
|
if (update_result != cudaGraphExecUpdateSuccess) {
|
||||||
|
cudaGetLastError(); // reset error
|
||||||
|
graph_exec.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (graph_exec == nullptr) {
|
||||||
|
graph_exec.instantiate(graph_);
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save cuda graph to dot file
|
||||||
|
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
|
||||||
|
static int count = 0;
|
||||||
|
auto path = fmt::format("{}_{}.dot", filename, ++count);
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
|
||||||
|
}
|
||||||
|
|
||||||
// Reset state
|
// Reset state
|
||||||
graph_node_count_ = 0;
|
|
||||||
empty_node_count_ = 0;
|
|
||||||
from_nodes_.clear();
|
from_nodes_.clear();
|
||||||
to_nodes_.clear();
|
to_nodes_.clear();
|
||||||
graph_key_.clear();
|
graph_deps_key_.clear();
|
||||||
|
graph_nodes_key_.clear();
|
||||||
node_map_.clear();
|
node_map_.clear();
|
||||||
graph_ = CudaGraph(device_);
|
graph_ = CudaGraph(device_);
|
||||||
|
is_graph_updatable_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put completion handlers in a batch.
|
// Put completion handlers in a batch.
|
||||||
|
|||||||
@@ -106,8 +106,9 @@ class CommandEncoder {
|
|||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
// K = kernel
|
// K = kernel
|
||||||
// E = empty
|
// E = empty
|
||||||
// G = subgraph
|
// G* = subgraph (with metadata)
|
||||||
char node_type;
|
// Symbols ':', '-' are reserved as separators
|
||||||
|
std::string node_type;
|
||||||
std::string id;
|
std::string id;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -119,12 +120,11 @@ class CommandEncoder {
|
|||||||
CudaGraph graph_;
|
CudaGraph graph_;
|
||||||
Worker worker_;
|
Worker worker_;
|
||||||
char node_count_{0};
|
char node_count_{0};
|
||||||
char graph_node_count_{0};
|
|
||||||
char empty_node_count_{0};
|
|
||||||
bool in_concurrent_{false};
|
bool in_concurrent_{false};
|
||||||
std::vector<cudaGraphNode_t> from_nodes_;
|
std::vector<cudaGraphNode_t> from_nodes_;
|
||||||
std::vector<cudaGraphNode_t> to_nodes_;
|
std::vector<cudaGraphNode_t> to_nodes_;
|
||||||
std::string graph_key_;
|
std::string graph_nodes_key_;
|
||||||
|
std::string graph_deps_key_;
|
||||||
std::vector<GraphNode> concurrent_nodes_;
|
std::vector<GraphNode> concurrent_nodes_;
|
||||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||||
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
||||||
@@ -132,6 +132,7 @@ class CommandEncoder {
|
|||||||
std::vector<std::uintptr_t> active_outputs_;
|
std::vector<std::uintptr_t> active_outputs_;
|
||||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||||
size_t bytes_in_graph_{0};
|
size_t bytes_in_graph_{0};
|
||||||
|
bool is_graph_updatable_{true};
|
||||||
int max_ops_per_graph_;
|
int max_ops_per_graph_;
|
||||||
int max_mb_per_graph_;
|
int max_mb_per_graph_;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -305,6 +305,7 @@ void Event::wait() {
|
|||||||
} else {
|
} else {
|
||||||
event->atomic->wait(value());
|
event->atomic->wait(value());
|
||||||
}
|
}
|
||||||
|
CHECK_CUDA_ERROR(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Event::wait(Stream s) {
|
void Event::wait(Stream s) {
|
||||||
|
|||||||
@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
|
||||||
template <typename T, int BLOCK_DIM>
|
template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
|
||||||
struct BlockBroadcastReduce {
|
struct BlockBroadcastReduce {
|
||||||
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
|
using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
|
||||||
static_assert(BLOCK_DIM % WARP_SIZE == 0);
|
|
||||||
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
|
|
||||||
|
|
||||||
cg::thread_block& block;
|
cg::thread_block& block;
|
||||||
TempStorage& temp;
|
TempStorage& temp;
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<GROUP_DIM>(block);
|
||||||
T x = cg::reduce(warp, input, op);
|
T x = cg::reduce(warp, input, op);
|
||||||
if (warp.thread_rank() == 0) {
|
if constexpr (BLOCK_DIM > GROUP_DIM) {
|
||||||
temp[warp.meta_group_rank()] = x;
|
if (warp.thread_rank() == 0) {
|
||||||
|
temp[warp.meta_group_rank()] = x;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
||||||
|
: init_value;
|
||||||
|
return cg::reduce(warp, x, op);
|
||||||
|
} else {
|
||||||
|
return x;
|
||||||
}
|
}
|
||||||
block.sync();
|
|
||||||
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
|
|
||||||
: init_value;
|
|
||||||
return cg::reduce(warp, x, op);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ T Sum(const T& input) {
|
__device__ T Sum(const T& input) {
|
||||||
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
|
||||||
|
__global__ void rms_norm_small(
|
||||||
|
const T* x,
|
||||||
|
const T* w,
|
||||||
|
T* out,
|
||||||
|
float eps,
|
||||||
|
uint32_t axis_size,
|
||||||
|
uint32_t n_rows,
|
||||||
|
int64_t w_stride) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
|
||||||
|
__shared__ typename BlockReduceT::TempStorage temp;
|
||||||
|
|
||||||
|
auto row =
|
||||||
|
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
|
||||||
|
if (row >= n_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
x += row * axis_size;
|
||||||
|
out += row * axis_size;
|
||||||
|
|
||||||
|
// Normalizer.
|
||||||
|
float normalizer = 0;
|
||||||
|
auto index = block.thread_index().x;
|
||||||
|
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
float t = static_cast<float>(xn[i]);
|
||||||
|
normalizer += t * t;
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
|
||||||
|
normalizer = rsqrt(normalizer / axis_size + eps);
|
||||||
|
|
||||||
|
// Outputs.
|
||||||
|
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
float y = static_cast<float>(xn[i]) * normalizer;
|
||||||
|
xn[i] = wn[i] * static_cast<T>(y);
|
||||||
|
}
|
||||||
|
store_vector<N_READS>(out, index, xn, axis_size);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
template <typename T, int BLOCK_DIM, int N_READS = 4>
|
||||||
__global__ void rms_norm(
|
__global__ void rms_norm(
|
||||||
const T* x,
|
const T* x,
|
||||||
@@ -94,6 +142,74 @@ __global__ void rms_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
bool HAS_W,
|
||||||
|
int BLOCK_DIM,
|
||||||
|
int REDUCE_DIM,
|
||||||
|
int N_READS = 4>
|
||||||
|
__global__ void rms_norm_vjp_small(
|
||||||
|
const T* x,
|
||||||
|
const T* w,
|
||||||
|
const T* g,
|
||||||
|
T* gx,
|
||||||
|
T* gw,
|
||||||
|
float eps,
|
||||||
|
int32_t axis_size,
|
||||||
|
int32_t n_rows,
|
||||||
|
int64_t w_stride) {
|
||||||
|
auto grid = cg::this_grid();
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
|
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
|
||||||
|
__shared__ typename BlockReduceF2::TempStorage temp;
|
||||||
|
|
||||||
|
auto row =
|
||||||
|
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
|
||||||
|
if (row >= n_rows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
x += row * axis_size;
|
||||||
|
g += row * axis_size;
|
||||||
|
gx += row * axis_size;
|
||||||
|
gw += row * axis_size;
|
||||||
|
|
||||||
|
// Normalizer.
|
||||||
|
float2 factors = {};
|
||||||
|
auto index = block.thread_index().x;
|
||||||
|
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
|
||||||
|
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
|
||||||
|
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
float t = static_cast<float>(xn[i]);
|
||||||
|
float wi = wn[i];
|
||||||
|
float gi = gn[i];
|
||||||
|
float wg = wi * gi;
|
||||||
|
factors = plus_f2(factors, {wg * t, t * t});
|
||||||
|
}
|
||||||
|
|
||||||
|
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
|
||||||
|
float meangwx = factors.x / axis_size;
|
||||||
|
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||||
|
float normalizer3 = normalizer * normalizer * normalizer;
|
||||||
|
|
||||||
|
// Outputs.
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
float xi = xn[i];
|
||||||
|
float wi = wn[i];
|
||||||
|
float gi = gn[i];
|
||||||
|
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
|
||||||
|
if constexpr (HAS_W) {
|
||||||
|
wn[i] = static_cast<T>(gi * xi * normalizer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
store_vector<N_READS>(gx, index, xn, axis_size);
|
||||||
|
if constexpr (HAS_W) {
|
||||||
|
store_vector<N_READS>(gw, index, wn, axis_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
|
||||||
__global__ void rms_norm_vjp(
|
__global__ void rms_norm_vjp(
|
||||||
const T* x,
|
const T* x,
|
||||||
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
|
|||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
|
|
||||||
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
|
|
||||||
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
|
||||||
__shared__ union {
|
__shared__ typename BlockReduceF2::TempStorage temp;
|
||||||
typename BlockReduceF::TempStorage f;
|
|
||||||
typename BlockReduceF2::TempStorage f2;
|
|
||||||
} temp;
|
|
||||||
|
|
||||||
x += grid.block_rank() * axis_size;
|
x += grid.block_rank() * axis_size;
|
||||||
g += grid.block_rank() * axis_size;
|
g += grid.block_rank() * axis_size;
|
||||||
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
|
|||||||
factors = plus_f2(factors, {wg * t, t * t});
|
factors = plus_f2(factors, {wg * t, t * t});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
|
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
|
||||||
float meangwx = factors.x / axis_size;
|
float meangwx = factors.x / axis_size;
|
||||||
float normalizer = rsqrt(factors.y / axis_size + eps);
|
float normalizer = rsqrt(factors.y / axis_size + eps);
|
||||||
float normalizer3 = normalizer * normalizer * normalizer;
|
float normalizer3 = normalizer * normalizer * normalizer;
|
||||||
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
|
|||||||
return s.device == Device::cpu;
|
return s.device == Device::cpu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int n_per_thread, typename F>
|
||||||
|
void dispatch_group_dim(int axis_size, F&& f) {
|
||||||
|
if (axis_size <= n_per_thread * 8) {
|
||||||
|
f(std::integral_constant<int, 8>{},
|
||||||
|
std::integral_constant<int, 1>(),
|
||||||
|
std::integral_constant<int, 16>());
|
||||||
|
} else if (axis_size <= n_per_thread * 16) {
|
||||||
|
f(std::integral_constant<int, 16>{},
|
||||||
|
std::integral_constant<int, 1>(),
|
||||||
|
std::integral_constant<int, 8>());
|
||||||
|
} else if (axis_size <= n_per_thread * 32) {
|
||||||
|
f(std::integral_constant<int, 32>{},
|
||||||
|
std::integral_constant<int, 1>(),
|
||||||
|
std::integral_constant<int, 4>());
|
||||||
|
} else if (axis_size <= n_per_thread * 32 * 2) {
|
||||||
|
f(std::integral_constant<int, 32>{},
|
||||||
|
std::integral_constant<int, 2>(),
|
||||||
|
std::integral_constant<int, 2>());
|
||||||
|
} else if (axis_size <= n_per_thread * 32 * 4) {
|
||||||
|
f(std::integral_constant<int, 32>{},
|
||||||
|
std::integral_constant<int, 4>(),
|
||||||
|
std::integral_constant<int, 1>());
|
||||||
|
} else if (axis_size <= n_per_thread * 32 * 8) {
|
||||||
|
f(std::integral_constant<int, 32>{},
|
||||||
|
std::integral_constant<int, 8>(),
|
||||||
|
std::integral_constant<int, 1>());
|
||||||
|
} else if (axis_size <= n_per_thread * 32 * 16) {
|
||||||
|
f(std::integral_constant<int, 32>{},
|
||||||
|
std::integral_constant<int, 16>(),
|
||||||
|
std::integral_constant<int, 1>());
|
||||||
|
} else {
|
||||||
|
f(std::integral_constant<int, 32>{},
|
||||||
|
std::integral_constant<int, 32>(),
|
||||||
|
std::integral_constant<int, 1>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
// TODO: There are duplicate code with backend/metal/normalization.cpp
|
||||||
void RMSNorm::eval_gpu(
|
void RMSNorm::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
|
|||||||
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
constexpr int N_READS = 16 / sizeof(DataType);
|
constexpr int N_READS = 16 / sizeof(DataType);
|
||||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
if (axis_size <= N_READS * 1024) {
|
||||||
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
|
dispatch_group_dim<N_READS>(
|
||||||
|
axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
|
||||||
|
constexpr int block_dim = n_groups() * group_dim();
|
||||||
|
auto kernel =
|
||||||
|
cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
|
||||||
|
auto n_blocks =
|
||||||
|
(n_rows + groups_per_block() - 1) / groups_per_block();
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
n_blocks,
|
||||||
|
{block_dim, groups_per_block()},
|
||||||
|
0,
|
||||||
|
gpu_ptr<DataType>(x),
|
||||||
|
gpu_ptr<DataType>(w),
|
||||||
|
gpu_ptr<DataType>(out),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
n_rows,
|
||||||
|
w_stride);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
1024,
|
||||||
0,
|
0,
|
||||||
gpu_ptr<DataType>(x),
|
gpu_ptr<DataType>(x),
|
||||||
gpu_ptr<DataType>(w),
|
gpu_ptr<DataType>(w),
|
||||||
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
|
|||||||
eps_,
|
eps_,
|
||||||
axis_size,
|
axis_size,
|
||||||
w_stride);
|
w_stride);
|
||||||
});
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
|
|||||||
dispatch_bool(has_w, [&](auto has_w_constant) {
|
dispatch_bool(has_w, [&](auto has_w_constant) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
constexpr int N_READS = 16 / sizeof(DataType);
|
constexpr int N_READS = 16 / sizeof(DataType);
|
||||||
dispatch_block_dim(
|
if (axis_size <= N_READS * 1024) {
|
||||||
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
dispatch_group_dim<N_READS>(
|
||||||
auto kernel = cu::rms_norm_vjp<
|
axis_size,
|
||||||
DataType,
|
[&](auto group_dim, auto n_groups, auto groups_per_block) {
|
||||||
has_w_constant.value,
|
constexpr int block_dim = group_dim() * n_groups();
|
||||||
block_dim(),
|
auto kernel = cu::rms_norm_vjp_small<
|
||||||
N_READS>;
|
DataType,
|
||||||
encoder.add_kernel_node(
|
has_w_constant.value,
|
||||||
kernel,
|
block_dim,
|
||||||
n_rows,
|
group_dim(),
|
||||||
block_dim(),
|
N_READS>;
|
||||||
0,
|
auto n_blocks =
|
||||||
gpu_ptr<DataType>(x),
|
(n_rows + groups_per_block() - 1) / groups_per_block();
|
||||||
gpu_ptr<DataType>(w),
|
encoder.add_kernel_node(
|
||||||
gpu_ptr<DataType>(g),
|
kernel,
|
||||||
gpu_ptr<DataType>(gx),
|
n_blocks,
|
||||||
gpu_ptr<DataType>(gw_temp),
|
{block_dim, groups_per_block()},
|
||||||
eps_,
|
0,
|
||||||
axis_size,
|
gpu_ptr<DataType>(x),
|
||||||
w_stride);
|
gpu_ptr<DataType>(w),
|
||||||
});
|
gpu_ptr<DataType>(g),
|
||||||
|
gpu_ptr<DataType>(gx),
|
||||||
|
gpu_ptr<DataType>(gw_temp),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
n_rows,
|
||||||
|
w_stride);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
auto kernel =
|
||||||
|
cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
n_rows,
|
||||||
|
1024,
|
||||||
|
0,
|
||||||
|
gpu_ptr<DataType>(x),
|
||||||
|
gpu_ptr<DataType>(w),
|
||||||
|
gpu_ptr<DataType>(g),
|
||||||
|
gpu_ptr<DataType>(gx),
|
||||||
|
gpu_ptr<DataType>(gw_temp),
|
||||||
|
eps_,
|
||||||
|
axis_size,
|
||||||
|
w_stride);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void malloc_with_same_layout(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& o,
|
||||||
|
const array& q) {
|
||||||
|
if (q.flags().row_contiguous) {
|
||||||
|
o.set_data(cu::malloc_async(o.nbytes(), encoder));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// fill_order = argsort(q.strides())
|
||||||
|
Shape fill_order(q.ndim());
|
||||||
|
std::iota(fill_order.begin(), fill_order.end(), 0);
|
||||||
|
std::stable_sort(
|
||||||
|
fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
|
||||||
|
auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
|
||||||
|
auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
|
||||||
|
return s1 < s2;
|
||||||
|
});
|
||||||
|
// Generate o_strides with fill_order
|
||||||
|
Strides o_strides(q.ndim());
|
||||||
|
int64_t stride = 1;
|
||||||
|
for (int i : fill_order) {
|
||||||
|
o_strides[i] = stride;
|
||||||
|
stride *= o.shape(i);
|
||||||
|
}
|
||||||
|
// o is a transposed contiguous array
|
||||||
|
o.set_data(
|
||||||
|
cu::malloc_async(o.nbytes(), encoder),
|
||||||
|
o.size(),
|
||||||
|
o_strides,
|
||||||
|
{true, false, false});
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int QKV_NDIM = 4;
|
constexpr int QKV_NDIM = 4;
|
||||||
|
|
||||||
struct SDPACacheKey {
|
struct SDPACacheKey {
|
||||||
@@ -75,6 +107,8 @@ struct SDPACacheKey {
|
|||||||
std::array<int64_t, QKV_NDIM> k_strides;
|
std::array<int64_t, QKV_NDIM> k_strides;
|
||||||
std::array<int64_t, QKV_NDIM> v_strides;
|
std::array<int64_t, QKV_NDIM> v_strides;
|
||||||
bool do_causal;
|
bool do_causal;
|
||||||
|
std::array<int, QKV_NDIM> mask_shape;
|
||||||
|
std::array<int64_t, QKV_NDIM> mask_strides;
|
||||||
bool output_logsumexp;
|
bool output_logsumexp;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -84,6 +118,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
|
|||||||
const array& k,
|
const array& k,
|
||||||
const array& v,
|
const array& v,
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
|
const std::optional<array>& mask_arr,
|
||||||
bool output_logsumexp = true) {
|
bool output_logsumexp = true) {
|
||||||
BytesKey<SDPACacheKey> cache_key;
|
BytesKey<SDPACacheKey> cache_key;
|
||||||
cache_key.pod = {
|
cache_key.pod = {
|
||||||
@@ -96,20 +131,26 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
|
|||||||
vector_key<QKV_NDIM>(k.strides()),
|
vector_key<QKV_NDIM>(k.strides()),
|
||||||
vector_key<QKV_NDIM>(v.strides()),
|
vector_key<QKV_NDIM>(v.strides()),
|
||||||
do_causal,
|
do_causal,
|
||||||
|
{},
|
||||||
|
{},
|
||||||
output_logsumexp,
|
output_logsumexp,
|
||||||
};
|
};
|
||||||
|
if (mask_arr) {
|
||||||
|
cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
|
||||||
|
cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
|
||||||
|
}
|
||||||
return cache_key;
|
return cache_key;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& sdpa_cache() {
|
auto& sdpa_cache() {
|
||||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||||
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
|
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
|
||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& sdpa_backward_cache() {
|
auto& sdpa_backward_cache() {
|
||||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||||
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
|
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
|
||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,6 +159,7 @@ enum UIDS {
|
|||||||
K,
|
K,
|
||||||
V,
|
V,
|
||||||
SCALE,
|
SCALE,
|
||||||
|
BIAS,
|
||||||
O,
|
O,
|
||||||
STATS,
|
STATS,
|
||||||
// Backward graph:
|
// Backward graph:
|
||||||
@@ -133,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
|
|||||||
const array& k,
|
const array& k,
|
||||||
const array& v,
|
const array& v,
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
|
const std::optional<array>& mask_arr,
|
||||||
bool output_logsumexp,
|
bool output_logsumexp,
|
||||||
const array& o,
|
const array& o,
|
||||||
const array& stats) {
|
const array& stats) {
|
||||||
@@ -164,8 +207,19 @@ fe::graph::Graph build_sdpa_graph(
|
|||||||
auto options = fe::graph::SDPA_attributes()
|
auto options = fe::graph::SDPA_attributes()
|
||||||
.set_name("sdpa_cudnn")
|
.set_name("sdpa_cudnn")
|
||||||
.set_attn_scale(scale)
|
.set_attn_scale(scale)
|
||||||
.set_causal_mask(do_causal)
|
|
||||||
.set_generate_stats(output_logsumexp);
|
.set_generate_stats(output_logsumexp);
|
||||||
|
if (do_causal) {
|
||||||
|
if (q.shape(2) > k.shape(2)) {
|
||||||
|
options.set_causal_mask(do_causal);
|
||||||
|
} else {
|
||||||
|
options.set_causal_mask_bottom_right(do_causal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mask_arr) {
|
||||||
|
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
|
||||||
|
set_tensor_attrs(bias_, BIAS, *mask_arr);
|
||||||
|
options.set_bias(bias_);
|
||||||
|
}
|
||||||
|
|
||||||
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
|
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
|
||||||
o_->set_output(true);
|
o_->set_output(true);
|
||||||
@@ -192,6 +246,7 @@ fe::graph::Graph build_sdpa_backward_graph(
|
|||||||
const array& k,
|
const array& k,
|
||||||
const array& v,
|
const array& v,
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
|
const std::optional<array>& mask_arr,
|
||||||
const array& o,
|
const array& o,
|
||||||
const array& d_o,
|
const array& d_o,
|
||||||
const array& stats,
|
const array& stats,
|
||||||
@@ -233,7 +288,19 @@ fe::graph::Graph build_sdpa_backward_graph(
|
|||||||
auto options = fe::graph::SDPA_backward_attributes()
|
auto options = fe::graph::SDPA_backward_attributes()
|
||||||
.set_name("sdpa_backward_cudnn")
|
.set_name("sdpa_backward_cudnn")
|
||||||
.set_attn_scale(scale)
|
.set_attn_scale(scale)
|
||||||
.set_causal_mask(do_causal);
|
.set_attn_scale(scale);
|
||||||
|
if (do_causal) {
|
||||||
|
if (q.shape(2) > k.shape(2)) {
|
||||||
|
options.set_causal_mask(do_causal);
|
||||||
|
} else {
|
||||||
|
options.set_causal_mask_bottom_right(do_causal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mask_arr) {
|
||||||
|
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
|
||||||
|
set_tensor_attrs(bias_, BIAS, *mask_arr);
|
||||||
|
options.set_bias(bias_);
|
||||||
|
}
|
||||||
|
|
||||||
auto [d_q_, d_k_, d_v_] =
|
auto [d_q_, d_k_, d_v_] =
|
||||||
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
|
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
|
||||||
@@ -286,7 +353,6 @@ bool supports_sdpa_cudnn(
|
|||||||
const array& q,
|
const array& q,
|
||||||
const array& k,
|
const array& k,
|
||||||
const array& v,
|
const array& v,
|
||||||
bool has_mask,
|
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
Stream s) {
|
Stream s) {
|
||||||
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
|
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
|
||||||
@@ -299,19 +365,8 @@ bool supports_sdpa_cudnn(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_mask) {
|
// Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
|
||||||
// TODO: Support array masks.
|
if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
|
||||||
if (!do_causal) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// FIXME: Causal mask generates wrong results when L_Q != L_K.
|
|
||||||
if (q.shape(2) != k.shape(2)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only use cuDNN for prefilling and training.
|
|
||||||
if (q.shape(2) != k.shape(2)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,32 +388,33 @@ void sdpa_cudnn(
|
|||||||
array& o,
|
array& o,
|
||||||
array& stats,
|
array& stats,
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
|
const std::optional<array>& mask_arr,
|
||||||
bool output_logsumexp,
|
bool output_logsumexp,
|
||||||
Stream s) {
|
Stream s) {
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
auto handle = encoder.device().cudnn_handle();
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
|
||||||
// TODO: Handle donation.
|
malloc_with_same_layout(encoder, o, q);
|
||||||
// TODO: Make O use same memory layout with Q.
|
|
||||||
o.set_data(cu::malloc_async(o.nbytes(), encoder));
|
|
||||||
|
|
||||||
encoder.set_input_array(q);
|
encoder.set_input_array(q);
|
||||||
encoder.set_input_array(k);
|
encoder.set_input_array(k);
|
||||||
encoder.set_input_array(v);
|
encoder.set_input_array(v);
|
||||||
encoder.set_output_array(o);
|
encoder.set_output_array(o);
|
||||||
|
if (mask_arr) {
|
||||||
|
encoder.set_input_array(*mask_arr);
|
||||||
|
}
|
||||||
if (output_logsumexp) {
|
if (output_logsumexp) {
|
||||||
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
|
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
|
||||||
encoder.set_output_array(stats);
|
encoder.set_output_array(stats);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search cache.
|
// Search cache.
|
||||||
auto cache_key =
|
auto cache_key = build_sdpa_cache_key(
|
||||||
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
|
encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
|
||||||
auto it = sdpa_cache().find(cache_key);
|
auto it = sdpa_cache().find(cache_key);
|
||||||
if (it == sdpa_cache().end()) {
|
if (it == sdpa_cache().end()) {
|
||||||
auto graph = build_sdpa_graph(
|
auto graph = build_sdpa_graph(
|
||||||
handle, q, k, v, do_causal, output_logsumexp, o, stats);
|
handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
|
||||||
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
|
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
|
||||||
}
|
}
|
||||||
auto& graph = it->second;
|
auto& graph = it->second;
|
||||||
@@ -369,6 +425,9 @@ void sdpa_cudnn(
|
|||||||
{V, const_cast<void*>(gpu_ptr<void>(v))},
|
{V, const_cast<void*>(gpu_ptr<void>(v))},
|
||||||
{SCALE, &scale},
|
{SCALE, &scale},
|
||||||
{O, gpu_ptr<void>(o)}};
|
{O, gpu_ptr<void>(o)}};
|
||||||
|
if (mask_arr) {
|
||||||
|
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
|
||||||
|
}
|
||||||
if (output_logsumexp) {
|
if (output_logsumexp) {
|
||||||
variant_pack[STATS] = gpu_ptr<void>(stats);
|
variant_pack[STATS] = gpu_ptr<void>(stats);
|
||||||
}
|
}
|
||||||
@@ -384,6 +443,7 @@ void sdpa_backward_cudnn(
|
|||||||
const array& o,
|
const array& o,
|
||||||
const array& stats,
|
const array& stats,
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
|
const std::optional<array>& mask_arr,
|
||||||
const array& d_o,
|
const array& d_o,
|
||||||
array& d_q,
|
array& d_q,
|
||||||
array& d_k,
|
array& d_k,
|
||||||
@@ -392,10 +452,9 @@ void sdpa_backward_cudnn(
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
auto handle = encoder.device().cudnn_handle();
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
|
||||||
// TODO: Handle donation.
|
malloc_with_same_layout(encoder, d_q, q);
|
||||||
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
|
malloc_with_same_layout(encoder, d_k, k);
|
||||||
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
|
malloc_with_same_layout(encoder, d_v, v);
|
||||||
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
|
|
||||||
|
|
||||||
encoder.set_input_array(q);
|
encoder.set_input_array(q);
|
||||||
encoder.set_input_array(k);
|
encoder.set_input_array(k);
|
||||||
@@ -406,13 +465,16 @@ void sdpa_backward_cudnn(
|
|||||||
encoder.set_output_array(d_q);
|
encoder.set_output_array(d_q);
|
||||||
encoder.set_output_array(d_k);
|
encoder.set_output_array(d_k);
|
||||||
encoder.set_output_array(d_v);
|
encoder.set_output_array(d_v);
|
||||||
|
if (mask_arr) {
|
||||||
|
encoder.set_input_array(*mask_arr);
|
||||||
|
}
|
||||||
|
|
||||||
// Search cache.
|
// Search cache.
|
||||||
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
|
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
|
||||||
auto it = sdpa_backward_cache().find(cache_key);
|
auto it = sdpa_backward_cache().find(cache_key);
|
||||||
if (it == sdpa_backward_cache().end()) {
|
if (it == sdpa_backward_cache().end()) {
|
||||||
auto graph = build_sdpa_backward_graph(
|
auto graph = build_sdpa_backward_graph(
|
||||||
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
|
handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
|
||||||
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
|
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
|
||||||
}
|
}
|
||||||
auto& graph = it->second;
|
auto& graph = it->second;
|
||||||
@@ -428,6 +490,9 @@ void sdpa_backward_cudnn(
|
|||||||
{D_Q, gpu_ptr<void>(d_q)},
|
{D_Q, gpu_ptr<void>(d_q)},
|
||||||
{D_K, gpu_ptr<void>(d_k)},
|
{D_K, gpu_ptr<void>(d_k)},
|
||||||
{D_V, gpu_ptr<void>(d_v)}};
|
{D_V, gpu_ptr<void>(d_v)}};
|
||||||
|
if (mask_arr) {
|
||||||
|
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
|
||||||
|
}
|
||||||
|
|
||||||
execute_graph(encoder, handle, graph, variant_pack);
|
execute_graph(encoder, handle, graph, variant_pack);
|
||||||
}
|
}
|
||||||
@@ -469,7 +534,11 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
|
|
||||||
return !supports_sdpa_vector(
|
return !supports_sdpa_vector(
|
||||||
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
|
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
|
||||||
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
|
!supports_sdpa_cudnn(q, k, v, do_causal, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ScaledDotProductAttention::supports_bool_mask() {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ScaledDotProductAttention::eval_gpu(
|
void ScaledDotProductAttention::eval_gpu(
|
||||||
@@ -487,6 +556,11 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
bool has_mask = inputs.size() - has_sinks_ > 3;
|
bool has_mask = inputs.size() - has_sinks_ > 3;
|
||||||
bool has_arr_mask = has_mask && !do_causal_;
|
bool has_arr_mask = has_mask && !do_causal_;
|
||||||
|
|
||||||
|
std::optional<array> mask_arr;
|
||||||
|
if (has_arr_mask) {
|
||||||
|
mask_arr = prepare_sdpa_input(inputs[3], s);
|
||||||
|
}
|
||||||
|
|
||||||
if (supports_sdpa_vector(
|
if (supports_sdpa_vector(
|
||||||
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
|
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
|
||||||
if (has_sinks_) {
|
if (has_sinks_) {
|
||||||
@@ -495,7 +569,17 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
|
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
|
sdpa_cudnn(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
scale_,
|
||||||
|
out,
|
||||||
|
stats,
|
||||||
|
do_causal_,
|
||||||
|
mask_arr,
|
||||||
|
output_logsumexp_,
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,13 +599,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(
|
|||||||
|
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
|
|
||||||
assert(inputs.size() == 6);
|
assert(inputs.size() >= 6);
|
||||||
|
int primals_size = inputs.size() - 3;
|
||||||
|
bool has_arr_mask = primals_size > 3 + has_sinks_;
|
||||||
|
|
||||||
array q = prepare_sdpa_input(inputs[0], s);
|
array q = prepare_sdpa_input(inputs[0], s);
|
||||||
array k = prepare_sdpa_input(inputs[1], s);
|
array k = prepare_sdpa_input(inputs[1], s);
|
||||||
array v = prepare_sdpa_input(inputs[2], s);
|
array v = prepare_sdpa_input(inputs[2], s);
|
||||||
array o = prepare_sdpa_input(inputs[3], s);
|
array o = prepare_sdpa_input(inputs[primals_size], s);
|
||||||
array stats = prepare_sdpa_input(inputs[4], s);
|
array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
|
||||||
array d_o = prepare_sdpa_input(inputs[5], s);
|
array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
|
||||||
|
|
||||||
|
std::optional<array> mask_arr;
|
||||||
|
if (has_arr_mask) {
|
||||||
|
mask_arr = prepare_sdpa_input(inputs[3], s);
|
||||||
|
}
|
||||||
|
|
||||||
assert(outputs.size() == 3);
|
assert(outputs.size() == 3);
|
||||||
auto& d_q = outputs[0];
|
auto& d_q = outputs[0];
|
||||||
@@ -529,7 +621,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
|
|||||||
auto& d_v = outputs[2];
|
auto& d_v = outputs[2];
|
||||||
|
|
||||||
sdpa_backward_cudnn(
|
sdpa_backward_cudnn(
|
||||||
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
|
q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,77 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
|||||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||||
|
|
||||||
mkdir -p "$OUTPUT_DIR"
|
mkdir -p "$OUTPUT_DIR"
|
||||||
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||||
|
|
||||||
|
CCC="xcrun -sdk macosx metal -x metal"
|
||||||
|
|
||||||
|
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||||
|
|
||||||
|
declare -a HDRS_LIST=($HDRS)
|
||||||
|
declare -a HDRS_STACK=()
|
||||||
|
declare -a HDRS_SORTED=()
|
||||||
|
|
||||||
|
length=${#HDRS_LIST[@]}
|
||||||
|
|
||||||
|
HDRS_LIST+=(".")
|
||||||
|
|
||||||
|
for ((i=0; i<${length}; i+=2));
|
||||||
|
do
|
||||||
|
|
||||||
|
header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
|
||||||
|
|
||||||
|
str_this="${HDRS_LIST[$i]}"
|
||||||
|
str_next="${HDRS_LIST[$i + 2]}"
|
||||||
|
|
||||||
|
depth_this=${#str_this}
|
||||||
|
depth_next=${#str_next}
|
||||||
|
|
||||||
|
# If we have a dependency then we stack it
|
||||||
|
if [ $depth_next -gt $depth_this ]; then
|
||||||
|
HDRS_STACK=($header ${HDRS_STACK[@]})
|
||||||
|
|
||||||
|
# If we are done with this level
|
||||||
|
else
|
||||||
|
# We add the header to out list
|
||||||
|
HDRS_SORTED+=($header)
|
||||||
|
|
||||||
|
# Pop the stacked up dependencies
|
||||||
|
pop_len=$((depth_this - depth_next))
|
||||||
|
for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
|
||||||
|
do
|
||||||
|
HDRS_SORTED+=($popped_header)
|
||||||
|
done
|
||||||
|
|
||||||
|
HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
|
||||||
|
fi
|
||||||
|
|
||||||
|
done
|
||||||
|
|
||||||
|
HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
|
||||||
|
|
||||||
|
CONTENT=$(
|
||||||
|
echo "// Copyright © 2025 Apple Inc."
|
||||||
|
echo ""
|
||||||
|
echo "// Auto generated source for $INPUT_FILE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
for header in "${HDRS_SORTED[@]}"
|
||||||
|
do
|
||||||
|
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||||
|
echo "// Contents from \"${header}\""
|
||||||
|
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
echo "#line 1 \"${header}\""
|
||||||
|
|
||||||
|
grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||||
|
)
|
||||||
|
|
||||||
cat << EOF > "$OUTPUT_FILE"
|
cat << EOF > "$OUTPUT_FILE"
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|||||||
@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
return !(supports_sdpa_full || supports_sdpa_vector);
|
return !(supports_sdpa_full || supports_sdpa_vector);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ScaledDotProductAttention::supports_bool_mask() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void ScaledDotProductAttention::eval_gpu(
|
void ScaledDotProductAttention::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool fast::ScaledDotProductAttention::supports_bool_mask() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool fast::ScaledDotProductAttentionVJP::use_fallback(
|
bool fast::ScaledDotProductAttentionVJP::use_fallback(
|
||||||
const array& q,
|
const array& q,
|
||||||
Stream s) {
|
Stream s) {
|
||||||
|
|||||||
@@ -5,3 +5,48 @@
|
|||||||
ncclResult_t ncclGetUniqueId(ncclUniqueId*) {
|
ncclResult_t ncclGetUniqueId(ncclUniqueId*) {
|
||||||
return ncclSuccess;
|
return ncclSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* ncclGetErrorString(ncclResult_t result) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclResult_t
|
||||||
|
ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclResult_t ncclCommDestroy(ncclComm_t comm) {
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclResult_t ncclAllGather(
|
||||||
|
const void* sendbuff,
|
||||||
|
void* recvbuff,
|
||||||
|
size_t sendcount,
|
||||||
|
ncclDataType_t datatype,
|
||||||
|
ncclComm_t comm,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclResult_t ncclAllReduce(
|
||||||
|
const void* sendbuff,
|
||||||
|
void* recvbuff,
|
||||||
|
size_t count,
|
||||||
|
ncclDataType_t datatype,
|
||||||
|
ncclRedOp_t op,
|
||||||
|
ncclComm_t comm,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
ncclResult_t ncclReduceScatter(
|
||||||
|
const void* sendbuff,
|
||||||
|
void* recvbuff,
|
||||||
|
size_t recvcount,
|
||||||
|
ncclDataType_t datatype,
|
||||||
|
ncclRedOp_t op,
|
||||||
|
ncclComm_t comm,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
return ncclSuccess;
|
||||||
|
}
|
||||||
|
|||||||
11
mlx/fast.cpp
11
mlx/fast.cpp
@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
|
|||||||
is_training,
|
is_training,
|
||||||
output_logsumexp,
|
output_logsumexp,
|
||||||
stream)) {
|
stream)) {
|
||||||
|
if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
|
||||||
|
// Convert bool mask to additive mask.
|
||||||
|
float inf = std::numeric_limits<float>::infinity();
|
||||||
|
array& mask = inputs[3];
|
||||||
|
mask = where(
|
||||||
|
mask,
|
||||||
|
full_like(mask, 0, final_type, s),
|
||||||
|
full_like(mask, -inf, final_type, s));
|
||||||
|
}
|
||||||
Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||||
auto primitive = std::make_shared<ScaledDotProductAttention>(
|
auto primitive = std::make_shared<ScaledDotProductAttention>(
|
||||||
stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
|
stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
|
||||||
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(
|
|||||||
|
|
||||||
std::vector<Shape> shapes;
|
std::vector<Shape> shapes;
|
||||||
std::vector<Dtype> dtypes;
|
std::vector<Dtype> dtypes;
|
||||||
for (int i = 0; i < primals.size(); ++i) {
|
for (int i = 0; i < /* outputs size */ 3; ++i) {
|
||||||
shapes.push_back(primals[i].shape());
|
shapes.push_back(primals[i].shape());
|
||||||
dtypes.push_back(primals[i].dtype());
|
dtypes.push_back(primals[i].dtype());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
bool is_training,
|
bool is_training,
|
||||||
bool output_logsumexp,
|
bool output_logsumexp,
|
||||||
Stream s);
|
Stream s);
|
||||||
|
static bool supports_bool_mask();
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
|
|||||||
@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
std::vector<array>
|
std::vector<array>
|
||||||
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::svd]");
|
check_cpu_stream(s, "[linalg::svd]");
|
||||||
check_float(a.dtype(), "[linalg::svd]");
|
check_float_or_complex(a.dtype(), "[linalg::svd]");
|
||||||
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
|||||||
s_shape.pop_back();
|
s_shape.pop_back();
|
||||||
s_shape[rank - 2] = std::min(m, n);
|
s_shape[rank - 2] = std::min(m, n);
|
||||||
|
|
||||||
|
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
|
||||||
|
|
||||||
if (!compute_uv) {
|
if (!compute_uv) {
|
||||||
return {array(
|
return {array(
|
||||||
std::move(s_shape),
|
std::move(s_shape),
|
||||||
a.dtype(),
|
s_dtype,
|
||||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||||
{a})};
|
{a})};
|
||||||
}
|
}
|
||||||
@@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
return array::make_arrays(
|
return array::make_arrays(
|
||||||
{u_shape, s_shape, vt_shape},
|
{u_shape, s_shape, vt_shape},
|
||||||
{a.dtype(), a.dtype(), a.dtype()},
|
{a.dtype(), s_dtype, a.dtype()},
|
||||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,11 @@ class Scheduler {
|
|||||||
|
|
||||||
~Scheduler() {
|
~Scheduler() {
|
||||||
for (auto s : streams_) {
|
for (auto s : streams_) {
|
||||||
synchronize(s);
|
try {
|
||||||
|
synchronize(s);
|
||||||
|
} catch (const std::runtime_error&) {
|
||||||
|
// ignore errors if synch fails
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto t : threads_) {
|
for (auto t : threads_) {
|
||||||
if (t != nullptr) {
|
if (t != nullptr) {
|
||||||
|
|||||||
@@ -407,7 +407,10 @@ class Module(dict):
|
|||||||
instance).
|
instance).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
apply_fn (Callable): The function to apply to the modules.
|
apply_fn (Callable): The function to apply to the modules which
|
||||||
|
takes two parameters. The first parameter is the string path of
|
||||||
|
the module (e.g. ``"model.layers.0.linear"``). The second
|
||||||
|
parameter is the module object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The module instance after updating submodules.
|
The module instance after updating submodules.
|
||||||
|
|||||||
@@ -1445,7 +1445,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"dtype"_a.none() = mx::float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
|
"def linspace(start: scalar, stop: scalar, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
|
Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
|
||||||
|
|
||||||
@@ -4021,7 +4021,7 @@ void init_ops(nb::module_& m) {
|
|||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array], Tuple[dict[str, array], dict[str, Any]]]"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Load array(s) from a binary file.
|
Load array(s) from a binary file.
|
||||||
|
|
||||||
@@ -4037,11 +4037,12 @@ void init_ops(nb::module_& m) {
|
|||||||
which support matadata. The metadata will be returned as an
|
which support matadata. The metadata will be returned as an
|
||||||
additional dictionary. Default: ``False``.
|
additional dictionary. Default: ``False``.
|
||||||
Returns:
|
Returns:
|
||||||
array or dict:
|
array, dict, or tuple:
|
||||||
A single array if loading from a ``.npy`` file or a dict
|
A single array if loading from a ``.npy`` file or a dict
|
||||||
mapping names to arrays if loading from a ``.npz`` or
|
mapping names to arrays if loading from a ``.npz`` or
|
||||||
``.safetensors`` file. If ``return_metadata`` is ``True`` an
|
``.safetensors`` file. If ``return_metadata`` is ``True`` a
|
||||||
additional dictionary of metadata will be returned.
|
tuple ``(arrays, metadata)`` will be returned where the second
|
||||||
|
element is a dictionary containing the metadata.
|
||||||
|
|
||||||
Warning:
|
Warning:
|
||||||
|
|
||||||
|
|||||||
@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
|
|||||||
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
|
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list(array): A list of the Jacobian-vector products which
|
tuple(list(array), list(array)): A tuple with the outputs of
|
||||||
is the same in number, shape, and type of the inputs to ``fun``.
|
``fun`` in the first position and the Jacobian-vector products
|
||||||
|
in the second position.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
|
||||||
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"vjp",
|
"vjp",
|
||||||
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
|
|||||||
same in number, shape, and type as the outputs of ``fun``.
|
same in number, shape, and type as the outputs of ``fun``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list(array): A list of the vector-Jacobian products which
|
tuple(list(array), list(array)): A tuple with the outputs of
|
||||||
is the same in number, shape, and type of the outputs of ``fun``.
|
``fun`` in the first position and the vector-Jacobian products
|
||||||
|
in the second position.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
|
||||||
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"value_and_grad",
|
"value_and_grad",
|
||||||
|
|||||||
@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
|
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
|
||||||
|
|
||||||
def test_sdpa_grad(self):
|
def test_sdpa_grad(self):
|
||||||
B, N_kv, T, D = (2, 8, 128, 64)
|
|
||||||
scale = D**-0.5
|
|
||||||
|
|
||||||
f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
|
|
||||||
f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
|
|
||||||
|
|
||||||
f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
|
|
||||||
f4 = lambda q, k, v: (
|
|
||||||
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
|
|
||||||
).sum()
|
|
||||||
|
|
||||||
# High tolerance due to cuDNN SDPA kernel requiring tf32.
|
# High tolerance due to cuDNN SDPA kernel requiring tf32.
|
||||||
tolerance = {"rtol": 1e-2, "atol": 1e-2}
|
tolerance = {"rtol": 1e-2, "atol": 1e-2}
|
||||||
|
|
||||||
|
def test_vjp(slow, fast, primals):
|
||||||
|
cotan = mx.ones_like(primals[0])
|
||||||
|
o1, vjp1 = mx.vjp(slow, primals, [cotan])
|
||||||
|
o2, vjp2 = mx.vjp(fast, primals, [cotan])
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
|
||||||
|
for i in range(3):
|
||||||
|
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
|
||||||
|
|
||||||
|
def test_grad(slow, fast, args):
|
||||||
|
g1 = mx.grad(slow)(*args)
|
||||||
|
g2 = mx.grad(fast)(*args)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(g1, g2, **tolerance))
|
||||||
|
|
||||||
|
sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
)
|
||||||
|
sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
).sum()
|
||||||
|
loss_mask_fast = lambda q, k, v, mask: (
|
||||||
|
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
|
).sum()
|
||||||
|
|
||||||
|
B, N_kv, T, D = (2, 8, 128, 64)
|
||||||
|
scale = D**-0.5
|
||||||
|
|
||||||
for N_q in (8, 32):
|
for N_q in (8, 32):
|
||||||
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
|
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
|
||||||
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
|
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
|
||||||
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
|
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
|
||||||
|
|
||||||
cotan = mx.ones_like(q)
|
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
|
||||||
o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
|
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
|
||||||
o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
|
|
||||||
|
|
||||||
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
|
for mask in (mask_additive, mask_bool):
|
||||||
for i in range(3):
|
test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
|
||||||
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
|
test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
|
||||||
|
|
||||||
g1 = mx.grad(f3)(q, k, v)
|
for mask in (None, "causal"):
|
||||||
g2 = mx.grad(f4)(q, k, v)
|
sdpa_slow = lambda q, k, v: mlx_ref_attn(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
)
|
||||||
|
sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
)
|
||||||
|
test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
|
||||||
|
|
||||||
self.assertTrue(mx.allclose(g1, g2, **tolerance))
|
loss_slow = lambda q, k, v: mlx_ref_attn(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
).sum()
|
||||||
|
loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
|
||||||
|
q, k, v, scale=scale, mask=mask
|
||||||
|
).sum()
|
||||||
|
test_grad(loss_slow, loss_fast, [q, k, v])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -168,6 +168,42 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test float64 - use CPU stream since float64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_f64 = mx.array(
|
||||||
|
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)
|
||||||
|
mx.eval(U_f64, S_f64, Vt_f64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,
|
||||||
|
A_f64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-7,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(S_f64.dtype, mx.float64)
|
||||||
|
|
||||||
|
# Test complex64 - use CPU stream since complex64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_c64 = mx.array(
|
||||||
|
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64
|
||||||
|
)
|
||||||
|
U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)
|
||||||
|
mx.eval(U_c64, S_c64, Vt_c64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,
|
||||||
|
A_c64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-7,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(S_c64.dtype, mx.float32)
|
||||||
|
self.assertEqual(U_c64.dtype, mx.complex64)
|
||||||
|
self.assertEqual(Vt_c64.dtype, mx.complex64)
|
||||||
|
|
||||||
def test_inverse(self):
|
def test_inverse(self):
|
||||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
||||||
A_inv = mx.linalg.inv(A, stream=mx.cpu)
|
A_inv = mx.linalg.inv(A, stream=mx.cpu)
|
||||||
@@ -342,6 +378,43 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
A_np = np.random.randn(3, n, n).astype(np.float32)
|
A_np = np.random.randn(3, n, n).astype(np.float32)
|
||||||
check_eigs_and_vecs(A_np)
|
check_eigs_and_vecs(A_np)
|
||||||
|
|
||||||
|
# Test float64 - use CPU stream since float64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)
|
||||||
|
A_f64 = mx.array(A_np_f64, dtype=mx.float64)
|
||||||
|
eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)
|
||||||
|
mx.eval(eig_vals_f64, eig_vecs_f64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
A_f64 @ eig_vecs_f64,
|
||||||
|
eig_vals_f64[..., None, :] * eig_vecs_f64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Eigenvalues should be complex64 (output dtype)
|
||||||
|
self.assertEqual(eig_vals_f64.dtype, mx.complex64)
|
||||||
|
self.assertEqual(eig_vecs_f64.dtype, mx.complex64)
|
||||||
|
|
||||||
|
# Test complex64 input - use CPU stream since complex64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_np_c64 = np.array(
|
||||||
|
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64
|
||||||
|
)
|
||||||
|
A_c64 = mx.array(A_np_c64, dtype=mx.complex64)
|
||||||
|
eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)
|
||||||
|
mx.eval(eig_vals_c64, eig_vecs_c64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
A_c64 @ eig_vecs_c64,
|
||||||
|
eig_vals_c64[..., None, :] * eig_vecs_c64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(eig_vals_c64.dtype, mx.complex64)
|
||||||
|
self.assertEqual(eig_vecs_c64.dtype, mx.complex64)
|
||||||
|
|
||||||
# Test error cases
|
# Test error cases
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array
|
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array
|
||||||
|
|||||||
Reference in New Issue
Block a user