Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Comparing 3e05cea9f8...jit-nax (12 commits):

5cf6f10bef
7c1abc50c0
2b95d0c270
b054838780
dd79d3c465
704fd1ae28
c9f4dc851f
f8bd675655
23a9168d34
bca205e287
1d4eacb737
8abd37ad05
@@ -1,18 +1,13 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'

inputs:
  toolkit:
    description: 'The CUDA toolkit'
    required: true

runs:
  using: "composite"
  steps:
    - name: Build package
      shell: bash
      env:
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
      run: |
        pip install auditwheel build patchelf setuptools
        python setup.py clean --all
.github/actions/build-cuda/action.yml (26)
@@ -1,26 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'

inputs:
  toolkit:
    description: 'The CUDA toolkit'
    required: true

runs:
  using: "composite"
  steps:
    - name: Install Python package
      shell: bash
      env:
        DEBUG: 1
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
      run: pip install --no-build-isolation -e ".[dev]" -v

    - name: Build CPP only
      shell: bash
      run: |
        cmake . -B build \
          -DMLX_BUILD_CUDA=ON \
          -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
          -DCMAKE_BUILD_TYPE=DEBUG
        cmake --build build -j $(nproc)
.github/actions/build-linux/action.yml (30)
@@ -1,25 +1,41 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'

inputs:
  toolkit:
    description: 'The toolkit to build with'
    required: false
    default: 'cpu'

runs:
  using: "composite"
  steps:
    - name: Install Python package
      id: python_build
      shell: sh
      env:
        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
        DEBUG: 1
      run: pip install --no-build-isolation -e ".[dev]" -v
        CMAKE_ARGS: >-
          -DCMAKE_COMPILE_WARNING_AS_ERROR=ON
          -DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
      run: |
        if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
          # There is no GPU in arm64 runner, use a common arch.
          CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
          # Can not build tests when the built executables can not run.
          CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
        fi
        pip install --no-build-isolation -e ".[dev]" -v
        # Pass the CMAKE_ARGS to following steps.
        echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT

    - name: Generate package stubs
      shell: sh
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Build CPP only
      shell: bash
      run: |
        mkdir -p build && cd build
        cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
        make -j $(nproc)
        cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
        cmake --build build -j $(nproc)
.github/actions/setup-linux/action.yml (7)
@@ -51,8 +51,6 @@ runs:
          # Note: the CI machine does not meet CUDA 13's driver requirement.
          # Compatibility matrix:
          # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
          # The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
          # it's *not* on the default toolkit path.
          PACKAGES: |
            {
              "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
@@ -60,13 +58,16 @@
              "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
            }
        run: |
          export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
          # The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
          # Jetson specific. SBSA means Arm Server Base System Architecture.
          ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
          wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
          sudo dpkg -i cuda-keyring_1.1-1_all.deb
          sudo apt-get update
          sudo apt-get install -y \
            libnccl2 libnccl-dev \
            ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
          echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH

      - name: CUDA packages and driver report
        if: ${{ startsWith(inputs.toolkit, 'cuda') }}
.github/actions/test-linux/action.yml (12)
@@ -1,8 +1,8 @@
name: 'Run Linux tests'

inputs:
  cpu-only:
    description: 'Skip GPU tests'
  has-gpu:
    description: 'Run GPU tests'
    required: false
    default: false

@@ -17,7 +17,7 @@ runs:
        echo "::endgroup::"

    - name: Run distributed tests
      if: ${{ inputs.cpu-only == 'true' }}
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      run: |
        echo "::group::Distributed tests"
@@ -30,7 +30,7 @@
        echo "::endgroup::"

    - name: Run Python tests - CPU
      if: ${{ inputs.cpu-only == 'true' }}
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      env:
        DEVICE: cpu
@@ -40,7 +40,7 @@
        echo "::endgroup::"

    - name: Run Python tests - GPU
      if: ${{ inputs.cpu-only == 'false' }}
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
@@ -59,7 +59,7 @@
        echo "::endgroup::"

    - name: Run CPP tests - GPU
      if: ${{ inputs.cpu-only == 'false' }}
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
@@ -17,29 +17,51 @@ concurrency:

jobs:
  check_lint:
    name: Check Lint
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: pre-commit/action@v3.0.1

  linux_build_and_test:
    name: Linux (cpu, ${{ matrix.arch }})
    needs: check_lint
    strategy:
      matrix:
        runner:
          - ubuntu-22.04
          - ubuntu-22.04-arm
      fail-fast: false
    runs-on: ${{ matrix.runner }}
      matrix:
        arch: ['x86_64', 'aarch64']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux

  cuda_build_and_test:
    name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
    if: github.repository == 'ml-explore/mlx'
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
        toolkit: ['cuda-12.6', 'cuda-12.9']
    runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          cpu-only: true
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/build-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/test-linux
        if: matrix.arch == 'x86_64'
        with:
          has-gpu: true

  mac_build_and_test:
    name: macOS (${{ matrix.macos-target }})
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
@@ -49,38 +71,21 @@ jobs:
      MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
      - uses: ./.github/actions/build-macos

  cuda_build_and_test:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      fail-fast: false
      matrix:
        toolkit: ['cuda-12.6', 'cuda-12.9']
    runs-on: gpu-t4-4-core
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/build-cuda
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/test-linux

  build_documentation:
    name: Build Documentation
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22.04
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  linux_fedora_build_cpp:
    name: Linux Fedora CPP Build (${{ matrix.arch }})
    name: Linux Fedora (${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
@@ -96,7 +101,7 @@ jobs:
      image: fedora:42
    steps:
      - name: Checkout code
        uses: actions/checkout@v5
        uses: actions/checkout@v6

      - name: CPP Build Test - No Release
        run: |
.github/workflows/documentation.yml (2)
@@ -10,7 +10,7 @@ jobs:
  build:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  deploy:
.github/workflows/nightly.yml (10)
@@ -16,7 +16,7 @@ jobs:
        python_version: ["3.10", "3.14"]
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux-release
        with:
@@ -46,14 +46,12 @@ jobs:
          - ubuntu-22.04-arm
    runs-on: ${{ matrix.runner }}
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux
        with:
          cpu-only: true

  build_mac_release:
    if: github.repository == 'ml-explore/mlx'
@@ -62,7 +60,7 @@ jobs:
        python-version: ["3.10", "3.13"]
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
@@ -82,7 +80,7 @@ jobs:
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22-large
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: 'cuda-12.9'
.github/workflows/release.yml (10)
@@ -25,7 +25,7 @@ jobs:
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  deploy_documentation:
@@ -53,7 +53,7 @@ jobs:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
@@ -86,7 +86,7 @@ jobs:
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}

    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
@@ -133,14 +133,12 @@ jobs:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
    steps:
      - uses: actions/checkout@v5
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: 'cuda-12.9'
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
          toolkit: 'cuda-12.9'
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
@@ -123,14 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
    mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()

# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
  execute_process(
    COMMAND bash detect_cuda_arch.sh
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    COMMAND __nvcc_device_query
    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  set(UPGRADABLE_ARCHITECTURES "90;100;121")
  if(MLX_CUDA_ARCHITECTURES STREQUAL "")
    message(
      FATAL_ERROR
      "Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
  elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
    # Use arch-specific compute capability whenever possible.
    set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
  endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -154,17 +154,21 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
  }
  lock.unlock();
  if (!buf) {
    buf = new CudaBuffer{nullptr, size, device};
    cudaError_t err;
    void* data = nullptr;
    if (device == -1) {
      err = cudaMallocManaged(&buf->data, size);
      err = cudaMallocManaged(&data, size);
    } else {
      err = cudaMallocAsync(&buf->data, size, stream);
      err = cudaMallocAsync(&data, size, stream);
    }
    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
      throw std::runtime_error(fmt::format(
          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
    }
    if (!data) {
      return Buffer{nullptr};
    }
    buf = new CudaBuffer{data, size, device};
  }
  lock.lock();
}
@@ -29,6 +29,10 @@ class CudaHandle {
  }

  ~CudaHandle() {
    // Skip if there was an error to avoid throwing in the destructors
    if (cudaPeekAtLastError() != cudaSuccess) {
      return;
    }
    reset();
  }
@@ -1,13 +0,0 @@
#!/bin/bash

arch=`__nvcc_device_query`
case "$arch" in
  "90")
    echo "90a" ;;
  "100")
    echo "100a" ;;
  "121")
    echo "121a" ;;
  *)
    echo "native" ;;
esac
@@ -24,12 +24,21 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}

bool use_cuda_graphs() {
  static bool use_graphs = []() {
    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
  }();
  static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
  return use_graphs;
}

const char* save_cuda_graphs_dot_file() {
  static const char* filename = []() -> const char* {
    const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
    if (env && std::strlen(env) == 0) {
      return nullptr;
    }
    return env;
  }();
  return filename;
}

} // namespace

Device::Device(int device) : device_(device) {
@@ -421,6 +430,14 @@ void CommandEncoder::commit() {

    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
  }

  // Save cuda graph to dot file
  if (const char* filename = save_cuda_graphs_dot_file(); filename) {
    static int count = 0;
    auto path = fmt::format("{}_{}.dot", filename, ++count);
    CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
  }

  // Reset state
  from_nodes_.clear();
  to_nodes_.clear();
@@ -305,6 +305,7 @@ void Event::wait() {
  } else {
    event->atomic->wait(value());
  }
  CHECK_CUDA_ERROR(cudaPeekAtLastError());
}

void Event::wait(Stream s) {
@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
}

// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
struct BlockBroadcastReduce {
  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
  static_assert(BLOCK_DIM % WARP_SIZE == 0);
  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];

  cg::thread_block& block;
  TempStorage& temp;

  template <typename Op>
  __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
    auto warp = cg::tiled_partition<WARP_SIZE>(block);
    auto warp = cg::tiled_partition<GROUP_DIM>(block);
    T x = cg::reduce(warp, input, op);
    if (warp.thread_rank() == 0) {
      temp[warp.meta_group_rank()] = x;
    if constexpr (BLOCK_DIM > GROUP_DIM) {
      if (warp.thread_rank() == 0) {
        temp[warp.meta_group_rank()] = x;
      }
      block.sync();
      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
                                                      : init_value;
      return cg::reduce(warp, x, op);
    } else {
      return x;
    }
    block.sync();
    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
                                                    : init_value;
    return cg::reduce(warp, x, op);
  }

  __device__ T Sum(const T& input) {
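Note: the hunk above parameterizes `BlockBroadcastReduce` on a `GROUP_DIM`, so a reduction can span a sub-group of threads instead of a full block; when `BLOCK_DIM == GROUP_DIM` the shared-memory round trip is skipped entirely. A minimal NumPy sketch of the same two-level reduce-and-broadcast (illustrative names only, not MLX or CUDA code):

```python
import numpy as np

def block_broadcast_reduce(values, block_dim, group_dim, op=np.add):
    # Reduce within each group of group_dim "threads", then reduce the
    # per-group partials and broadcast the result to every thread.
    assert block_dim % group_dim == 0 and len(values) == block_dim
    groups = values.reshape(block_dim // group_dim, group_dim)
    partials = op.reduce(groups, axis=1)   # per-group (per-warp) results
    total = op.reduce(partials)            # cross-group reduce
    return np.full(block_dim, total)       # broadcast to all threads

x = np.arange(64, dtype=np.float32)
print(block_broadcast_reduce(x, block_dim=64, group_dim=32)[0])  # 2016.0
```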
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
  }
};

template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
__global__ void rms_norm_small(
    const T* x,
    const T* w,
    T* out,
    float eps,
    uint32_t axis_size,
    uint32_t n_rows,
    int64_t w_stride) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
  __shared__ typename BlockReduceT::TempStorage temp;

  auto row =
      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
  if (row >= n_rows) {
    return;
  }
  x += row * axis_size;
  out += row * axis_size;

  // Normalizer.
  float normalizer = 0;
  auto index = block.thread_index().x;
  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    float t = static_cast<float>(xn[i]);
    normalizer += t * t;
  }

  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
  normalizer = rsqrt(normalizer / axis_size + eps);

  // Outputs.
  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    float y = static_cast<float>(xn[i]) * normalizer;
    xn[i] = wn[i] * static_cast<T>(y);
  }
  store_vector<N_READS>(out, index, xn, axis_size);
}

template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
    const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
  }
}

template <
    typename T,
    bool HAS_W,
    int BLOCK_DIM,
    int REDUCE_DIM,
    int N_READS = 4>
__global__ void rms_norm_vjp_small(
    const T* x,
    const T* w,
    const T* g,
    T* gx,
    T* gw,
    float eps,
    int32_t axis_size,
    int32_t n_rows,
    int64_t w_stride) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
  __shared__ typename BlockReduceF2::TempStorage temp;

  auto row =
      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
  if (row >= n_rows) {
    return;
  }

  x += row * axis_size;
  g += row * axis_size;
  gx += row * axis_size;
  gw += row * axis_size;

  // Normalizer.
  float2 factors = {};
  auto index = block.thread_index().x;
  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
  for (int i = 0; i < N_READS; i++) {
    float t = static_cast<float>(xn[i]);
    float wi = wn[i];
    float gi = gn[i];
    float wg = wi * gi;
    factors = plus_f2(factors, {wg * t, t * t});
  }

  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
  float meangwx = factors.x / axis_size;
  float normalizer = rsqrt(factors.y / axis_size + eps);
  float normalizer3 = normalizer * normalizer * normalizer;

  // Outputs.
  for (int i = 0; i < N_READS; i++) {
    float xi = xn[i];
    float wi = wn[i];
    float gi = gn[i];
    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
    if constexpr (HAS_W) {
      wn[i] = static_cast<T>(gi * xi * normalizer);
    }
  }
  store_vector<N_READS>(gx, index, xn, axis_size);
  if constexpr (HAS_W) {
    store_vector<N_READS>(gw, index, wn, axis_size);
  }
}

template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
    const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
  __shared__ union {
    typename BlockReduceF::TempStorage f;
    typename BlockReduceF2::TempStorage f2;
  } temp;
  __shared__ typename BlockReduceF2::TempStorage temp;

  x += grid.block_rank() * axis_size;
  g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@
    factors = plus_f2(factors, {wg * t, t * t});
    }
  }
  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
  float meangwx = factors.x / axis_size;
  float normalizer = rsqrt(factors.y / axis_size + eps);
  float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
  return s.device == Device::cpu;
}

template <int n_per_thread, typename F>
void dispatch_group_dim(int axis_size, F&& f) {
  if (axis_size <= n_per_thread * 8) {
    f(std::integral_constant<int, 8>{},
      std::integral_constant<int, 1>(),
      std::integral_constant<int, 16>());
  } else if (axis_size <= n_per_thread * 16) {
    f(std::integral_constant<int, 16>{},
      std::integral_constant<int, 1>(),
      std::integral_constant<int, 8>());
  } else if (axis_size <= n_per_thread * 32) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 1>(),
      std::integral_constant<int, 4>());
  } else if (axis_size <= n_per_thread * 32 * 2) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 2>(),
      std::integral_constant<int, 2>());
  } else if (axis_size <= n_per_thread * 32 * 4) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 4>(),
      std::integral_constant<int, 1>());
  } else if (axis_size <= n_per_thread * 32 * 8) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 8>(),
      std::integral_constant<int, 1>());
  } else if (axis_size <= n_per_thread * 32 * 16) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 16>(),
      std::integral_constant<int, 1>());
  } else {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 32>(),
      std::integral_constant<int, 1>());
  }
}

// TODO: There are duplicate code with backend/metal/normalization.cpp
void RMSNorm::eval_gpu(
    const std::vector<array>& inputs,
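Note: `dispatch_group_dim` maps a row length to a launch shape at compile time: `group_dim * n_groups` threads cooperate on one row, and `groups_per_block` short rows are packed into one block. A hedged Python sketch of the same threshold table (illustrative only, not MLX API):

```python
def choose_group_dims(axis_size, n_per_thread):
    # (group_dim, n_groups, groups_per_block), mirroring the C++ thresholds:
    # short rows use few threads but many rows per block; long rows grow the
    # per-row thread count up to 32 * 32 = 1024.
    table = [
        (8, 1, 16), (16, 1, 8), (32, 1, 4), (32, 2, 2),
        (32, 4, 1), (32, 8, 1), (32, 16, 1), (32, 32, 1),
    ]
    limits = [8, 16, 32, 64, 128, 256, 512]
    for limit, cfg in zip(limits, table):
        if axis_size <= n_per_thread * limit:
            return cfg
    return table[-1]

# axis_size = 1024 with 8 elements per thread -> 128 threads per row, 1 row per block
print(choose_group_dims(1024, n_per_thread=8))  # (32, 4, 1)
```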
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
  dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr int N_READS = 16 / sizeof(DataType);
    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
    if (axis_size <= N_READS * 1024) {
      dispatch_group_dim<N_READS>(
          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
            constexpr int block_dim = n_groups() * group_dim();
            auto kernel =
                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
            auto n_blocks =
                (n_rows + groups_per_block() - 1) / groups_per_block();
            encoder.add_kernel_node(
                kernel,
                n_blocks,
                {block_dim, groups_per_block()},
                0,
                gpu_ptr<DataType>(x),
                gpu_ptr<DataType>(w),
                gpu_ptr<DataType>(out),
                eps_,
                axis_size,
                n_rows,
                w_stride);
          });
    } else {
      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
      encoder.add_kernel_node(
          kernel,
          n_rows,
          block_dim(),
          1024,
          0,
          gpu_ptr<DataType>(x),
          gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@
          eps_,
          axis_size,
          w_stride);
    });
    }
  });
}

@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
  dispatch_bool(has_w, [&](auto has_w_constant) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr int N_READS = 16 / sizeof(DataType);
    dispatch_block_dim(
        cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
          auto kernel = cu::rms_norm_vjp<
              DataType,
              has_w_constant.value,
              block_dim(),
              N_READS>;
          encoder.add_kernel_node(
              kernel,
              n_rows,
              block_dim(),
              0,
              gpu_ptr<DataType>(x),
              gpu_ptr<DataType>(w),
              gpu_ptr<DataType>(g),
              gpu_ptr<DataType>(gx),
              gpu_ptr<DataType>(gw_temp),
              eps_,
              axis_size,
              w_stride);
        });
    if (axis_size <= N_READS * 1024) {
      dispatch_group_dim<N_READS>(
          axis_size,
          [&](auto group_dim, auto n_groups, auto groups_per_block) {
            constexpr int block_dim = group_dim() * n_groups();
            auto kernel = cu::rms_norm_vjp_small<
                DataType,
                has_w_constant.value,
                block_dim,
                group_dim(),
                N_READS>;
            auto n_blocks =
                (n_rows + groups_per_block() - 1) / groups_per_block();
            encoder.add_kernel_node(
                kernel,
                n_blocks,
                {block_dim, groups_per_block()},
                0,
                gpu_ptr<DataType>(x),
                gpu_ptr<DataType>(w),
                gpu_ptr<DataType>(g),
                gpu_ptr<DataType>(gx),
                gpu_ptr<DataType>(gw_temp),
                eps_,
                axis_size,
                n_rows,
                w_stride);
          });
    } else {
      auto kernel =
          cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
      encoder.add_kernel_node(
          kernel,
          n_rows,
          1024,
          0,
          gpu_ptr<DataType>(x),
          gpu_ptr<DataType>(w),
          gpu_ptr<DataType>(g),
          gpu_ptr<DataType>(gx),
          gpu_ptr<DataType>(gw_temp),
          eps_,
          axis_size,
          w_stride);
    }
  });
  });
@@ -63,6 +63,38 @@ array prepare_sdpa_input(const array& x, Stream s) {
  return x;
}

void malloc_with_same_layout(
    cu::CommandEncoder& encoder,
    array& o,
    const array& q) {
  if (q.flags().row_contiguous) {
    o.set_data(cu::malloc_async(o.nbytes(), encoder));
    return;
  }
  // fill_order = argsort(q.strides())
  Shape fill_order(q.ndim());
  std::iota(fill_order.begin(), fill_order.end(), 0);
  std::stable_sort(
      fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
        auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
        auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
        return s1 < s2;
      });
  // Generate o_strides with fill_order
  Strides o_strides(q.ndim());
  int64_t stride = 1;
  for (int i : fill_order) {
    o_strides[i] = stride;
    stride *= o.shape(i);
  }
  // o is a transposed contiguous array
  o.set_data(
      cu::malloc_async(o.nbytes(), encoder),
      o.size(),
      o_strides,
      {true, false, false});
}

constexpr int QKV_NDIM = 4;

struct SDPACacheKey {
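Note: `malloc_with_same_layout` builds the output strides by arg-sorting the input strides and filling axes in ascending-stride order, so the output walks memory the same way the (possibly transposed) input does. A pure-Python restatement of that stride construction (illustrative only, not MLX API):

```python
def strides_like(in_strides, out_shape):
    # Order axes by ascending input stride (broadcast strides treated as 1),
    # then assign output strides in that order.
    order = sorted(range(len(in_strides)), key=lambda i: max(in_strides[i], 1))
    strides, s = [0] * len(out_shape), 1
    for i in order:
        strides[i] = s
        s *= out_shape[i]
    return strides

# A (B, H, T, D) tensor stored in B, T, H, D order (head axis transposed):
print(strides_like([65536, 64, 512, 1], (2, 8, 128, 64)))  # [65536, 64, 512, 1]
```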
@@ -75,6 +107,8 @@ struct SDPACacheKey {
  std::array<int64_t, QKV_NDIM> k_strides;
  std::array<int64_t, QKV_NDIM> v_strides;
  bool do_causal;
  std::array<int, QKV_NDIM> mask_shape;
  std::array<int64_t, QKV_NDIM> mask_strides;
  bool output_logsumexp;
};

@@ -84,6 +118,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
    const array& k,
    const array& v,
    bool do_causal,
    const std::optional<array>& mask_arr,
    bool output_logsumexp = true) {
  BytesKey<SDPACacheKey> cache_key;
  cache_key.pod = {
@@ -96,20 +131,26 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
      vector_key<QKV_NDIM>(k.strides()),
      vector_key<QKV_NDIM>(v.strides()),
      do_causal,
      {},
      {},
      output_logsumexp,
  };
  if (mask_arr) {
    cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
    cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
  }
  return cache_key;
}

auto& sdpa_cache() {
  static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
  return cache;
}

auto& sdpa_backward_cache() {
  static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
  return cache;
}
@@ -118,6 +159,7 @@ enum UIDS {
  K,
  V,
  SCALE,
  BIAS,
  O,
  STATS,
  // Backward graph:
@@ -133,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
    const array& k,
    const array& v,
    bool do_causal,
    const std::optional<array>& mask_arr,
    bool output_logsumexp,
    const array& o,
    const array& stats) {
@@ -164,8 +207,19 @@ fe::graph::Graph build_sdpa_graph(
  auto options = fe::graph::SDPA_attributes()
                     .set_name("sdpa_cudnn")
                     .set_attn_scale(scale)
                     .set_causal_mask(do_causal)
                     .set_generate_stats(output_logsumexp);
  if (do_causal) {
    if (q.shape(2) > k.shape(2)) {
      options.set_causal_mask(do_causal);
    } else {
      options.set_causal_mask_bottom_right(do_causal);
    }
  }
  if (mask_arr) {
    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
    set_tensor_attrs(bias_, BIAS, *mask_arr);
    options.set_bias(bias_);
  }

  auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
  o_->set_output(true);
@@ -192,6 +246,7 @@ fe::graph::Graph build_sdpa_backward_graph(
    const array& k,
    const array& v,
    bool do_causal,
    const std::optional<array>& mask_arr,
    const array& o,
    const array& d_o,
    const array& stats,
@@ -233,7 +288,19 @@ fe::graph::Graph build_sdpa_backward_graph(
  auto options = fe::graph::SDPA_backward_attributes()
                     .set_name("sdpa_backward_cudnn")
                     .set_attn_scale(scale)
                     .set_causal_mask(do_causal);
                     .set_attn_scale(scale);
  if (do_causal) {
    if (q.shape(2) > k.shape(2)) {
      options.set_causal_mask(do_causal);
    } else {
      options.set_causal_mask_bottom_right(do_causal);
    }
  }
  if (mask_arr) {
    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
    set_tensor_attrs(bias_, BIAS, *mask_arr);
    options.set_bias(bias_);
  }

  auto [d_q_, d_k_, d_v_] =
      graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
@@ -286,7 +353,6 @@ bool supports_sdpa_cudnn(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool do_causal,
    Stream s) {
  static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
@@ -299,19 +365,8 @@ bool supports_sdpa_cudnn(
    return false;
  }

  if (has_mask) {
    // TODO: Support array masks.
    if (!do_causal) {
      return false;
    }
    // FIXME: Causal mask generates wrong results when L_Q != L_K.
    if (q.shape(2) != k.shape(2)) {
      return false;
    }
  }

  // Only use cuDNN for prefilling and training.
  if (q.shape(2) != k.shape(2)) {
  // Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
  if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
    return false;
  }
@@ -333,32 +388,33 @@ void sdpa_cudnn(
    array& o,
    array& stats,
    bool do_causal,
    const std::optional<array>& mask_arr,
    bool output_logsumexp,
    Stream s) {
  auto& encoder = cu::get_command_encoder(s);
  auto handle = encoder.device().cudnn_handle();

  // TODO: Handle donation.
  // TODO: Make O use same memory layout with Q.
  o.set_data(cu::malloc_async(o.nbytes(), encoder));
  malloc_with_same_layout(encoder, o, q);

  encoder.set_input_array(q);
  encoder.set_input_array(k);
  encoder.set_input_array(v);
  encoder.set_output_array(o);

  if (mask_arr) {
    encoder.set_input_array(*mask_arr);
  }
  if (output_logsumexp) {
    stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
    encoder.set_output_array(stats);
  }

  // Search cache.
  auto cache_key =
      build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
  auto cache_key = build_sdpa_cache_key(
      encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
  auto it = sdpa_cache().find(cache_key);
  if (it == sdpa_cache().end()) {
    auto graph = build_sdpa_graph(
        handle, q, k, v, do_causal, output_logsumexp, o, stats);
        handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
    it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
  }
  auto& graph = it->second;
@@ -369,6 +425,9 @@ void sdpa_cudnn(
      {V, const_cast<void*>(gpu_ptr<void>(v))},
      {SCALE, &scale},
      {O, gpu_ptr<void>(o)}};
  if (mask_arr) {
    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
  }
  if (output_logsumexp) {
    variant_pack[STATS] = gpu_ptr<void>(stats);
  }
@@ -384,6 +443,7 @@ void sdpa_backward_cudnn(
    const array& o,
    const array& stats,
    bool do_causal,
    const std::optional<array>& mask_arr,
    const array& d_o,
    array& d_q,
    array& d_k,
@@ -392,10 +452,9 @@ void sdpa_backward_cudnn(
  auto& encoder = cu::get_command_encoder(s);
  auto handle = encoder.device().cudnn_handle();

  // TODO: Handle donation.
  d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
  d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
  d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
  malloc_with_same_layout(encoder, d_q, q);
  malloc_with_same_layout(encoder, d_k, k);
  malloc_with_same_layout(encoder, d_v, v);

  encoder.set_input_array(q);
  encoder.set_input_array(k);
@@ -406,13 +465,16 @@ void sdpa_backward_cudnn(
  encoder.set_output_array(d_q);
  encoder.set_output_array(d_k);
  encoder.set_output_array(d_v);
  if (mask_arr) {
    encoder.set_input_array(*mask_arr);
  }

  // Search cache.
  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
  auto it = sdpa_backward_cache().find(cache_key);
  if (it == sdpa_backward_cache().end()) {
    auto graph = build_sdpa_backward_graph(
        handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
        handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
    it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
  }
  auto& graph = it->second;
@@ -428,6 +490,9 @@ void sdpa_backward_cudnn(
      {D_Q, gpu_ptr<void>(d_q)},
      {D_K, gpu_ptr<void>(d_k)},
      {D_V, gpu_ptr<void>(d_v)}};
  if (mask_arr) {
    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
  }

  execute_graph(encoder, handle, graph, variant_pack);
}
@@ -469,7 +534,11 @@ bool ScaledDotProductAttention::use_fallback(

  return !supports_sdpa_vector(
             q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
      !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
      !supports_sdpa_cudnn(q, k, v, do_causal, s);
}

bool ScaledDotProductAttention::supports_bool_mask() {
  return false;
}

void ScaledDotProductAttention::eval_gpu(
@@ -487,6 +556,11 @@ void ScaledDotProductAttention::eval_gpu(
  bool has_mask = inputs.size() - has_sinks_ > 3;
  bool has_arr_mask = has_mask && !do_causal_;

  std::optional<array> mask_arr;
  if (has_arr_mask) {
    mask_arr = prepare_sdpa_input(inputs[3], s);
  }

  if (supports_sdpa_vector(
          q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
    if (has_sinks_) {
@@ -495,7 +569,17 @@ void ScaledDotProductAttention::eval_gpu(
      sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
    }
  } else {
    sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
    sdpa_cudnn(
        q,
        k,
        v,
        scale_,
        out,
        stats,
        do_causal_,
        mask_arr,
        output_logsumexp_,
        s);
  }
}
@@ -515,13 +599,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(

  auto& s = stream();

  assert(inputs.size() == 6);
  assert(inputs.size() >= 6);
  int primals_size = inputs.size() - 3;
  bool has_arr_mask = primals_size > 3 + has_sinks_;

  array q = prepare_sdpa_input(inputs[0], s);
  array k = prepare_sdpa_input(inputs[1], s);
  array v = prepare_sdpa_input(inputs[2], s);
  array o = prepare_sdpa_input(inputs[3], s);
  array stats = prepare_sdpa_input(inputs[4], s);
  array d_o = prepare_sdpa_input(inputs[5], s);
  array o = prepare_sdpa_input(inputs[primals_size], s);
  array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
  array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);

  std::optional<array> mask_arr;
  if (has_arr_mask) {
    mask_arr = prepare_sdpa_input(inputs[3], s);
  }

  assert(outputs.size() == 3);
  auto& d_q = outputs[0];
@@ -529,7 +621,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
  auto& d_v = outputs[2];

  sdpa_backward_cudnn(
      q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
      q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
}

} // namespace fast
@@ -16,7 +16,77 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp

mkdir -p "$OUTPUT_DIR"
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)

CCC="xcrun -sdk macosx metal -x metal"

HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )

declare -a HDRS_LIST=($HDRS)
declare -a HDRS_STACK=()
declare -a HDRS_SORTED=()

length=${#HDRS_LIST[@]}

HDRS_LIST+=(".")

for ((i=0; i<${length}; i+=2));
do

  header="${HDRS_LIST[$i+1]#$SRC_DIR/}"

  str_this="${HDRS_LIST[$i]}"
  str_next="${HDRS_LIST[$i + 2]}"

  depth_this=${#str_this}
  depth_next=${#str_next}

  # If we have a dependency then we stack it
  if [ $depth_next -gt $depth_this ]; then
    HDRS_STACK=($header ${HDRS_STACK[@]})

  # If we are done with this level
  else
    # We add the header to out list
    HDRS_SORTED+=($header)

    # Pop the stacked up dependencies
    pop_len=$((depth_this - depth_next))
    for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
    do
      HDRS_SORTED+=($popped_header)
    done

    HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
  fi

done

HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")

CONTENT=$(
  echo "// Copyright © 2025 Apple Inc."
  echo ""
  echo "// Auto generated source for $INPUT_FILE"
  echo ""

  for header in "${HDRS_SORTED[@]}"
  do
    echo "///////////////////////////////////////////////////////////////////////////////"
    echo "// Contents from \"${header}\""
    echo "///////////////////////////////////////////////////////////////////////////////"
    echo ""

    echo "#line 1 \"${header}\""

    grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"

    echo ""

  done

  echo "///////////////////////////////////////////////////////////////////////////////"
)

cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {
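Note: instead of fully preprocessing the kernel header, the script now asks the Metal compiler for its include tree with `-H` (each included header is printed to stderr, prefixed with dots giving its include depth) and emits headers in post-order, so every header's dependencies precede it in the generated source. A hedged Python sketch of the same ordering step, assuming the standard `-H` dot-prefixed output format (illustrative names, not part of the build):

```python
def order_headers(h_output_lines):
    # Each line looks like ". top.h" or ".. nested.h"; deeper entries are
    # includes of the entry above them. Emit children before their includer.
    entries = [(len(dots), path) for dots, path in
               (line.split(None, 1) for line in h_output_lines)]
    entries.append((0, None))            # sentinel, like HDRS_LIST+=(".")
    stack, ordered = [], []
    for (depth, path), (next_depth, _) in zip(entries, entries[1:]):
        if next_depth > depth:
            stack.insert(0, path)        # header still waiting on its includes
        else:
            ordered.append(path)
            pop = depth - next_depth
            ordered.extend(stack[:pop])
            stack = stack[pop:]
    return ordered

print(order_headers([". top.h", ".. dep.h", ". other.h"]))
# -> ['dep.h', 'top.h', 'other.h']
```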
@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
  return !(supports_sdpa_full || supports_sdpa_vector);
}

bool ScaledDotProductAttention::supports_bool_mask() {
  return true;
}

void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
  return true;
}

bool fast::ScaledDotProductAttention::supports_bool_mask() {
  return false;
}

bool fast::ScaledDotProductAttentionVJP::use_fallback(
    const array& q,
    Stream s) {
@@ -5,3 +5,48 @@
ncclResult_t ncclGetUniqueId(ncclUniqueId*) {
  return ncclSuccess;
}

const char* ncclGetErrorString(ncclResult_t result) {
  return nullptr;
}

ncclResult_t
ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
  return ncclSuccess;
}

ncclResult_t ncclCommDestroy(ncclComm_t comm) {
  return ncclSuccess;
}

ncclResult_t ncclAllGather(
    const void* sendbuff,
    void* recvbuff,
    size_t sendcount,
    ncclDataType_t datatype,
    ncclComm_t comm,
    cudaStream_t stream) {
  return ncclSuccess;
}

ncclResult_t ncclAllReduce(
    const void* sendbuff,
    void* recvbuff,
    size_t count,
    ncclDataType_t datatype,
    ncclRedOp_t op,
    ncclComm_t comm,
    cudaStream_t stream) {
  return ncclSuccess;
}

ncclResult_t ncclReduceScatter(
    const void* sendbuff,
    void* recvbuff,
    size_t recvcount,
    ncclDataType_t datatype,
    ncclRedOp_t op,
    ncclComm_t comm,
    cudaStream_t stream) {
  return ncclSuccess;
}
mlx/fast.cpp (11)
@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
          is_training,
          output_logsumexp,
          stream)) {
    if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
      // Convert bool mask to additive mask.
      float inf = std::numeric_limits<float>::infinity();
      array& mask = inputs[3];
      mask = where(
          mask,
          full_like(mask, 0, final_type, s),
          full_like(mask, -inf, final_type, s));
    }
    Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
    auto primitive = std::make_shared<ScaledDotProductAttention>(
        stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(

  std::vector<Shape> shapes;
  std::vector<Dtype> dtypes;
  for (int i = 0; i < primals.size(); ++i) {
  for (int i = 0; i < /* outputs size */ 3; ++i) {
    shapes.push_back(primals[i].shape());
    dtypes.push_back(primals[i].dtype());
  }
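Note: backends that report `supports_bool_mask() == false` (the CUDA path above) get the boolean mask rewritten into an additive bias before the primitive is built: kept positions contribute 0 and masked positions contribute -inf. The same rewrite expressed with the Python API, for illustration:

```python
import mlx.core as mx

# True means "attend", False means "mask out".
bool_mask = mx.array([[True, False], [True, True]])

# Additive form: add 0 where attending, -inf where masked.
additive = mx.where(
    bool_mask,
    mx.zeros(bool_mask.shape, dtype=mx.float32),
    mx.full(bool_mask.shape, float("-inf"), dtype=mx.float32),
)
```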
@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
      bool is_training,
      bool output_logsumexp,
      Stream s);
  static bool supports_bool_mask();

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
@@ -135,7 +135,11 @@ class Scheduler {

  ~Scheduler() {
    for (auto s : streams_) {
      synchronize(s);
      try {
        synchronize(s);
      } catch (const std::runtime_error&) {
        // ignore errors if synch fails
      }
    }
    for (auto t : threads_) {
      if (t != nullptr) {
@@ -407,7 +407,10 @@ class Module(dict):
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
            apply_fn (Callable): The function to apply to the modules which
                takes two parameters. The first parameter is the string path of
                the module (e.g. ``"model.layers.0.linear"``). The second
                parameter is the module object.

        Returns:
            The module instance after updating submodules.
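Note: the updated docstring spells out the two-argument callback signature (path string, then module). A short hedged usage sketch; the toy model here is only for illustration:

```python
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def report(path, module):
    # First argument is the dotted path, second is the module object.
    print(path or "<root>", type(module).__name__)

model.apply_to_modules(report)
```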
@@ -1445,7 +1445,7 @@ void init_ops(nb::module_& m) {
      "dtype"_a.none() = mx::float32,
      "stream"_a = nb::none(),
      nb::sig(
          "def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
          "def linspace(start: scalar, stop: scalar, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
            same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).

        Returns:
            list(array): A list of the Jacobian-vector products which
            is the same in number, shape, and type of the inputs to ``fun``.
            tuple(list(array), list(array)): A tuple with the outputs of
            ``fun`` in the first position and the Jacobian-vector products
            in the second position.

        Example:

          .. code-block:: python

            import mlx.core as mx

            outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))

      )pbdoc");
  m.def(
      "vjp",
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
            same in number, shape, and type as the outputs of ``fun``.

        Returns:
            list(array): A list of the vector-Jacobian products which
            is the same in number, shape, and type of the outputs of ``fun``.
            tuple(list(array), list(array)): A tuple with the outputs of
            ``fun`` in the first position and the vector-Jacobian products
            in the second position.

        Example:

          .. code-block:: python

            import mlx.core as mx

            outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))

      )pbdoc");
  m.def(
      "value_and_grad",
@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
        self.assertTrue(mx.allclose(out, expected, atol=1e-5))

    def test_sdpa_grad(self):
        B, N_kv, T, D = (2, 8, 128, 64)
        scale = D**-0.5

        f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
        f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

        f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
        f4 = lambda q, k, v: (
            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
        ).sum()

        # High tolerance due to cuDNN SDPA kernel requiring tf32.
        tolerance = {"rtol": 1e-2, "atol": 1e-2}

        def test_vjp(slow, fast, primals):
            cotan = mx.ones_like(primals[0])
            o1, vjp1 = mx.vjp(slow, primals, [cotan])
            o2, vjp2 = mx.vjp(fast, primals, [cotan])

            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
            for i in range(3):
                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))

        def test_grad(slow, fast, args):
            g1 = mx.grad(slow)(*args)
            g2 = mx.grad(fast)(*args)

            self.assertTrue(mx.allclose(g1, g2, **tolerance))

        sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
            q, k, v, scale=scale, mask=mask
        )
        sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
            q, k, v, scale=scale, mask=mask
        )

        loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
            q, k, v, scale=scale, mask=mask
        ).sum()
        loss_mask_fast = lambda q, k, v, mask: (
            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
        ).sum()

        B, N_kv, T, D = (2, 8, 128, 64)
        scale = D**-0.5

        for N_q in (8, 32):
            q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
            k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
            v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)

            cotan = mx.ones_like(q)
            o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
            o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
            mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
            mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5

            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
            for i in range(3):
                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
            for mask in (mask_additive, mask_bool):
                test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
                test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])

            g1 = mx.grad(f3)(q, k, v)
            g2 = mx.grad(f4)(q, k, v)
            for mask in (None, "causal"):
                sdpa_slow = lambda q, k, v: mlx_ref_attn(
                    q, k, v, scale=scale, mask=mask
                )
                sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
                    q, k, v, scale=scale, mask=mask
                )
                test_vjp(sdpa_slow, sdpa_fast, [q, k, v])

            self.assertTrue(mx.allclose(g1, g2, **tolerance))
                loss_slow = lambda q, k, v: mlx_ref_attn(
                    q, k, v, scale=scale, mask=mask
                ).sum()
                loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
                    q, k, v, scale=scale, mask=mask
                ).sum()
                test_grad(loss_slow, loss_fast, [q, k, v])


if __name__ == "__main__":