Compare commits


9 Commits

Author | SHA1 | Message | Date
Angelos Katharopoulos | 7a82455b35 | Add a no_ibv | 2025-11-20 12:52:35 -08:00
Angelos Katharopoulos | 643a9a6ba6 | Add empty sum_scatter | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 82097a8f85 | Add send/recv | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 29d9cd836a | Make sure that there is space for work completions | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 2d10020178 | Add working reduce and semi-working all gather | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 031e62539a | Fix ring | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 97f74543b1 | Fix side channel initialization for more than 2 peers | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 0dbe63397d | All gather | 2025-11-20 12:36:19 -08:00
Angelos Katharopoulos | 873df2e0f7 | Initial working all reduce | 2025-11-20 12:36:16 -08:00
52 changed files with 2176 additions and 1627 deletions

View File

@@ -1,13 +1,18 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
toolkit:
description: 'The CUDA toolkit'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all

.github/actions/build-cuda/action.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
toolkit:
description: 'The CUDA toolkit'
required: true
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
run: pip install --no-build-isolation -e ".[dev]" -v
- name: Build CPP only
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)

View File

@@ -1,41 +1,25 @@
name: 'Build and Test on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
description: 'Build and test MLX on Linux'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
DEBUG: 1
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
run: pip install --no-build-isolation -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build CPP only
shell: bash
run: |
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)

View File

@@ -51,6 +51,8 @@ runs:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
# it's *not* on the default toolkit path.
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
@@ -58,16 +60,13 @@ runs:
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}

View File

@@ -1,8 +1,8 @@
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
cpu-only:
description: 'Skip GPU tests'
required: false
default: false
@@ -17,7 +17,7 @@ runs:
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.has-gpu == 'false' }}
if: ${{ inputs.cpu-only == 'true' }}
shell: bash
run: |
echo "::group::Distributed tests"
@@ -30,7 +30,7 @@ runs:
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
if: ${{ inputs.cpu-only == 'true' }}
shell: bash
env:
DEVICE: cpu
@@ -40,7 +40,7 @@ runs:
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
if: ${{ inputs.cpu-only == 'false' }}
shell: bash
env:
DEVICE: gpu
@@ -59,7 +59,7 @@ runs:
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
if: ${{ inputs.cpu-only == 'false' }}
shell: bash
env:
DEVICE: gpu

View File

@@ -10,7 +10,7 @@ jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:

View File

@@ -16,7 +16,7 @@ jobs:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
@@ -46,12 +46,14 @@ jobs:
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
with:
cpu-only: true
build_mac_release:
if: github.repository == 'ml-explore/mlx'
@@ -60,7 +62,7 @@ jobs:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
@@ -80,7 +82,7 @@ jobs:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'

View File

@@ -13,55 +13,33 @@ permissions:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
fail-fast: false
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
cpu-only: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
@@ -71,21 +49,38 @@ jobs:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
if: github.repository == 'ml-explore/mlx'
strategy:
fail-fast: false
matrix:
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-cuda
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora (${{ matrix.arch }})
name: Linux Fedora CPP Build (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
@@ -101,7 +96,7 @@ jobs:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |

View File

@@ -25,7 +25,7 @@ jobs:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
@@ -53,7 +53,7 @@ jobs:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
@@ -86,7 +86,7 @@ jobs:
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
@@ -133,12 +133,14 @@ jobs:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:

View File

@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(

View File

@@ -12,167 +12,6 @@ namespace mlx::core {
namespace {
template <typename T>
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T>
void eig_impl(
array& a,
@@ -180,39 +19,101 @@ void eig_impl(
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto val_ptr = values.data<complex64_t>();
auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
complex64_t* vec_ptr = nullptr;
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<complex64_t>();
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
val_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(a_ptr, val_ptr, vec_ptr);
a_ptr += N * N;
val_ptr += N;
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
if (work.info != 0) {
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< work.info;
<< info;
throw std::runtime_error(msg.str());
}
}
@@ -264,17 +165,8 @@ void Eig::eval_cpu(
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}
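
Context for the eig.cpp hunk above: the rewritten eig_impl keeps LAPACK's usual two-phase protocol, first querying the optimal workspace size with lwork = -1 and then running the factorization once per matrix, and it still unpacks real geev output where a nonzero imaginary part marks a conjugate eigenvector pair stored across two consecutive rows. A minimal standalone sketch of that workspace-query protocol, written against the plain Fortran sgeev_ symbol rather than the MLX geev<T> wrapper:

#include <stdexcept>
#include <string>
#include <vector>

extern "C" void sgeev_(
    const char* jobvl, const char* jobvr, const int* n, float* a,
    const int* lda, float* wr, float* wi, float* vl, const int* ldvl,
    float* vr, const int* ldvr, float* work, const int* lwork, int* info);

// Eigenvalues only (jobvl = jobvr = 'N'); wr/wi receive the real and
// imaginary parts, much like eig_tmp and eig_tmp + N in the code above.
void eig_values(float* a, int n, float* wr, float* wi) {
  char jobvl = 'N', jobvr = 'N';
  int ldvl = 1, ldvr = 1, info = 0, lwork = -1;
  float work_query = 0.0f;
  // Phase 1: workspace query. With lwork = -1 LAPACK only reports the
  // optimal work size in work_query and performs no factorization.
  sgeev_(&jobvl, &jobvr, &n, a, &n, wr, wi, nullptr, &ldvl, nullptr, &ldvr,
         &work_query, &lwork, &info);
  lwork = static_cast<int>(work_query);
  std::vector<float> work(lwork);
  // Phase 2: the actual decomposition with the sized workspace.
  sgeev_(&jobvl, &jobvr, &n, a, &n, wr, wi, nullptr, &ldvl, nullptr, &ldvr,
         work.data(), &lwork, &info);
  if (info != 0) {
    throw std::runtime_error("geev failed with code " + std::to_string(info));
  }
}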

View File

@@ -45,7 +45,9 @@
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
@@ -61,20 +63,3 @@ INSTANTIATE_LAPACK_REAL(trtri)
}
INSTANTIATE_LAPACK_COMPLEX(heevd)
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)

View File

@@ -8,183 +8,6 @@
namespace mlx::core {
template <typename T, class Enable = void>
struct SVDWork {};
template <typename T>
struct SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <>
struct SVDWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
const int lrwork =
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
int lwork_query = -1;
int work_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension.real();
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <typename T>
void svd_impl(
const array& a,
@@ -204,8 +27,6 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
using R = typename SVDWork<T>::R;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
@@ -221,7 +42,7 @@ void svd_impl(
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
R* s_ptr;
T* s_ptr;
T* vt_ptr;
if (compute_uv) {
@@ -237,7 +58,7 @@ void svd_impl(
encoder.set_output_array(s);
encoder.set_output_array(vt);
s_ptr = s.data<R>();
s_ptr = s.data<T>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
@@ -247,26 +68,96 @@ void svd_impl(
encoder.set_output_array(s);
s_ptr = s.data<R>();
s_ptr = s.data<T>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
auto jobz = (u_ptr) ? 'A' : 'N';
SVDWork<T> svd_work(N, M, K, jobz);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
svd_work.run(
in_ptr + M * N * i,
s_ptr + K * i,
vt_ptr ? vt_ptr + N * N * i : nullptr,
u_ptr ? u_ptr + M * M * i : nullptr);
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
});
encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -277,12 +168,9 @@ void SVD::eval_cpu(
case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
case complex64:
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
"[SVD::eval_cpu] only supports float32 or float64.");
}
}
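
A note on the gesdd calls above: LAPACK is column-major, so a row-major M x N input is handed over as its transpose by swapping the dimension arguments, and because A = U S Vᵀ implies Aᵀ = V S Uᵀ, the buffer we want filled with Vᵀ is passed in LAPACK's u slot and the buffer for U in the vt slot (this is what the "identity above" comments refer to). A tiny illustrative helper, not MLX code, that spells the mapping out:

// Hypothetical argument-mapping helper for a row-major SVD via column-major
// LAPACK. Given row-major A (M x N), call gesdd with m = N, n = M and swap
// the U / Vᵀ output buffers.
struct GesddArgs {
  int m;        // passed to LAPACK as m: our N
  int n;        // passed to LAPACK as n: our M
  float* u;     // LAPACK's U buffer: receives our Vᵀ (N x N)
  float* vt;    // LAPACK's Vᵀ buffer: receives our U (M x M)
};

GesddArgs map_row_major_svd(int M, int N, float* u_out, float* vt_out) {
  return GesddArgs{N, M, vt_out, u_out};
}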

View File

@@ -123,21 +123,14 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Use native CUDA arch by default.
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
execute_process(
COMMAND __nvcc_device_query
COMMAND bash detect_cuda_arch.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(UPGRADABLE_ARCHITECTURES "90;100;121")
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
message(
FATAL_ERROR
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
# Use arch-specific compute capability whenever possible.
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES

View File

@@ -154,21 +154,17 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
void* data = nullptr;
if (device == -1) {
err = cudaMallocManaged(&data, size);
err = cudaMallocManaged(&buf->data, size);
} else {
err = cudaMallocAsync(&data, size, stream);
err = cudaMallocAsync(&buf->data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
if (!data) {
return Buffer{nullptr};
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
}
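
For reference, the error handling shown in this allocator hunk treats a hard CUDA failure differently from plain out-of-memory: only unexpected errors throw, while out-of-memory yields a null buffer so the caller can release cached memory and retry. A simplified standalone sketch of that shape, not the MLX allocator itself:

#include <cuda_runtime.h>
#include <cstddef>
#include <stdexcept>

struct CudaBufferSketch {
  void* data;
  size_t size;
};

// Returns nullptr on out-of-memory, throws on any other CUDA error.
CudaBufferSketch* malloc_async_sketch(size_t size, cudaStream_t stream) {
  void* data = nullptr;
  cudaError_t err = cudaMallocAsync(&data, size, stream);
  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
  if (data == nullptr) {
    return nullptr;  // caller may release cached buffers and retry
  }
  return new CudaBufferSketch{data, size};
}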

View File

@@ -29,10 +29,6 @@ class CudaHandle {
}
~CudaHandle() {
// Skip if there was an error to avoid throwing in the destructors
if (cudaPeekAtLastError() != cudaSuccess) {
return;
}
reset();
}

View File

@@ -0,0 +1,13 @@
#!/bin/bash
arch=`__nvcc_device_query`
case "$arch" in
"90")
echo "90a" ;;
"100")
echo "100a" ;;
"121")
echo "121a" ;;
*)
echo "native" ;;
esac

View File

@@ -24,19 +24,10 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
bool use_cuda_graphs() {
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
return use_graphs;
}
const char* save_cuda_graphs_dot_file() {
static const char* filename = []() -> const char* {
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
if (env && std::strlen(env) == 0) {
return nullptr;
}
return env;
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}();
return filename;
return use_graphs;
}
} // namespace
@@ -124,17 +115,18 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
enc.empty_node_count_++;
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
// Insert the concurrent -> empty node dependencies
for (auto& from : enc.concurrent_nodes_) {
enc.from_nodes_.push_back(from.node);
enc.to_nodes_.push_back(empty.node);
enc.graph_deps_key_ += from.id;
enc.graph_deps_key_ += "-";
enc.graph_deps_key_ += empty.id;
enc.graph_deps_key_ += "-";
enc.graph_key_ += from.id;
enc.graph_key_ += from.node_type;
enc.graph_key_ += empty.id;
enc.graph_key_ += empty.node_type;
}
// Insert the input -> concurrent node dependencies without updating output
@@ -149,6 +141,9 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
if (node.node_type == 'G') {
graph_node_count_++;
}
node.id = std::to_string(node_count_++);
if (in_concurrent_) {
concurrent_nodes_.push_back(std::move(node));
@@ -160,10 +155,6 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
}
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& node : nodes) {
graph_nodes_key_ += node.node_type;
graph_nodes_key_ += "-";
}
std::vector<GraphNode> deps;
{
// Dependencies must be added in the same order to produce a consistent
@@ -191,10 +182,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& to : nodes) {
from_nodes_.push_back(from.node);
to_nodes_.push_back(to.node);
graph_deps_key_ += from.id;
graph_deps_key_ += "-";
graph_deps_key_ += to.id;
graph_deps_key_ += "-";
graph_key_ += from.id;
graph_key_ += from.node_type;
graph_key_ += to.id;
graph_key_ += to.node_type;
}
}
}
@@ -318,46 +309,13 @@ void CommandEncoder::add_kernel_node(
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, "K"});
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, "K"});
}
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// CUDA graphs do not get updated correctly if a kernel node getting updated
// has a different cluster shape than the node it's being updated with.
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return true;
}
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type != cudaGraphNodeTypeKernel) {
return false;
}
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
return true;
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -370,11 +328,8 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
int cluster_dim_x = 0;
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
insert_graph_dependencies(GraphNode{node, 'G'});
}
bool CommandEncoder::needs_commit() {
@@ -399,53 +354,44 @@ void CommandEncoder::commit() {
from_nodes_.size()));
}
device_.make_current();
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(graph_node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
if (!is_graph_updatable_) {
CudaGraphExec graph_exec;
graph_exec.instantiate(graph_);
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
} else {
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
auto& graph_exec = graph_cache_[graph_key];
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Save cuda graph to dot file
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
static int count = 0;
auto path = fmt::format("{}_{}.dot", filename, ++count);
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_deps_key_.clear();
graph_nodes_key_.clear();
graph_key_.clear();
node_map_.clear();
graph_ = CudaGraph(device_);
is_graph_updatable_ = true;
}
// Put completion handlers in a batch.
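
The commit() changes above revolve around caching instantiated CUDA graph executables keyed by a string describing the graph's node and dependency structure, then attempting cudaGraphExecUpdate on a cache hit and re-instantiating only when the update is rejected. A condensed sketch of that launch path (CUDA 12 signatures, error checking and the MLX cache types omitted):

#include <cuda_runtime.h>

// exec is the cached executable for this graph key; it may be null on a miss.
void launch_cached_graph(cudaGraph_t graph, cudaGraphExec_t& exec, cudaStream_t stream) {
  if (exec != nullptr) {
    cudaGraphExecUpdateResultInfo info;
    cudaGraphExecUpdate(exec, graph, &info);
    if (info.result != cudaGraphExecUpdateSuccess) {
      cudaGetLastError();        // clear the sticky error from the failed update
      cudaGraphExecDestroy(exec);
      exec = nullptr;            // fall through and re-instantiate
    }
  }
  if (exec == nullptr) {
    cudaGraphInstantiate(&exec, graph, 0);
  }
  cudaGraphLaunch(exec, stream);
}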

View File

@@ -106,9 +106,8 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// G* = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
// G = subgraph
char node_type;
std::string id;
};
@@ -120,11 +119,12 @@ class CommandEncoder {
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};
char empty_node_count_{0};
bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_;
std::vector<cudaGraphNode_t> to_nodes_;
std::string graph_nodes_key_;
std::string graph_deps_key_;
std::string graph_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
@@ -132,7 +132,6 @@ class CommandEncoder {
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
size_t bytes_in_graph_{0};
bool is_graph_updatable_{true};
int max_ops_per_graph_;
int max_mb_per_graph_;
};

View File

@@ -305,7 +305,6 @@ void Event::wait() {
} else {
event->atomic->wait(value());
}
CHECK_CUDA_ERROR(cudaPeekAtLastError());
}
void Event::wait(Stream s) {

View File

@@ -22,28 +22,26 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
}
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<GROUP_DIM>(block);
auto warp = cg::tiled_partition<WARP_SIZE>(block);
T x = cg::reduce(warp, input, op);
if constexpr (BLOCK_DIM > GROUP_DIM) {
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
} else {
return x;
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
@@ -51,52 +49,6 @@ struct BlockBroadcastReduce {
}
};
template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
__global__ void rms_norm_small(
const T* x,
const T* w,
T* out,
float eps,
uint32_t axis_size,
uint32_t n_rows,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
auto row =
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
if (row >= n_rows) {
return;
}
x += row * axis_size;
out += row * axis_size;
// Normalizer.
float normalizer = 0;
auto index = block.thread_index().x;
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]);
normalizer += t * t;
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float y = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(y);
}
store_vector<N_READS>(out, index, xn, axis_size);
}
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
const T* x,
@@ -142,74 +94,6 @@ __global__ void rms_norm(
}
}
template <
typename T,
bool HAS_W,
int BLOCK_DIM,
int REDUCE_DIM,
int N_READS = 4>
__global__ void rms_norm_vjp_small(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int32_t n_rows,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
__shared__ typename BlockReduceF2::TempStorage temp;
auto row =
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
if (row >= n_rows) {
return;
}
x += row * axis_size;
g += row * axis_size;
gx += row * axis_size;
gw += row * axis_size;
// Normalizer.
float2 factors = {};
auto index = block.thread_index().x;
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]);
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f2(factors, {wg * t, t * t});
}
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
// Outputs.
for (int i = 0; i < N_READS; i++) {
float xi = xn[i];
float wi = wn[i];
float gi = gn[i];
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
if constexpr (HAS_W) {
wn[i] = static_cast<T>(gi * xi * normalizer);
}
}
store_vector<N_READS>(gx, index, xn, axis_size);
if constexpr (HAS_W) {
store_vector<N_READS>(gw, index, wn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
const T* x,
@@ -223,8 +107,12 @@ __global__ void rms_norm_vjp(
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
__shared__ typename BlockReduceF2::TempStorage temp;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF2::TempStorage f2;
} temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
@@ -246,7 +134,7 @@ __global__ void rms_norm_vjp(
factors = plus_f2(factors, {wg * t, t * t});
}
}
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
@@ -281,43 +169,6 @@ bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
template <int n_per_thread, typename F>
void dispatch_group_dim(int axis_size, F&& f) {
if (axis_size <= n_per_thread * 8) {
f(std::integral_constant<int, 8>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 16>());
} else if (axis_size <= n_per_thread * 16) {
f(std::integral_constant<int, 16>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 8>());
} else if (axis_size <= n_per_thread * 32) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 4>());
} else if (axis_size <= n_per_thread * 32 * 2) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 2>(),
std::integral_constant<int, 2>());
} else if (axis_size <= n_per_thread * 32 * 4) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 4>(),
std::integral_constant<int, 1>());
} else if (axis_size <= n_per_thread * 32 * 8) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 8>(),
std::integral_constant<int, 1>());
} else if (axis_size <= n_per_thread * 32 * 16) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 16>(),
std::integral_constant<int, 1>());
} else {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 32>(),
std::integral_constant<int, 1>());
}
}
// TODO: There are duplicate code with backend/metal/normalization.cpp
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
@@ -365,33 +216,12 @@ void RMSNorm::eval_gpu(
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
if (axis_size <= N_READS * 1024) {
dispatch_group_dim<N_READS>(
axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = n_groups() * group_dim();
auto kernel =
cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
1024,
block_dim(),
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
@@ -399,7 +229,7 @@ void RMSNorm::eval_gpu(
eps_,
axis_size,
w_stride);
}
});
});
}
@@ -476,51 +306,27 @@ void RMSNormVJP::eval_gpu(
dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
if (axis_size <= N_READS * 1024) {
dispatch_group_dim<N_READS>(
axis_size,
[&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = group_dim() * n_groups();
auto kernel = cu::rms_norm_vjp_small<
DataType,
has_w_constant.value,
block_dim,
group_dim(),
N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel =
cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
1024,
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
}
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm_vjp<
DataType,
has_w_constant.value,
block_dim(),
N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
});
});
});
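
On the BlockBroadcastReduce changes above: the structure is a warp-level reduction followed by a shared-memory exchange so that every thread in the block ends up with the reduced value, as the "result is broadcasted to every thread" comment says. A self-contained device-side sketch of that idea using raw shuffle intrinsics instead of cooperative_groups (assumes blockDim.x is a multiple of 32 and at most 1024):

__device__ float block_broadcast_sum(float v) {
  __shared__ float warp_sums[32];
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  const int n_warps = blockDim.x / 32;
  // 1) Reduce within each warp.
  for (int off = 16; off > 0; off /= 2) {
    v += __shfl_down_sync(0xffffffff, v, off);
  }
  if (lane == 0) {
    warp_sums[warp] = v;
  }
  __syncthreads();
  // 2) Every warp re-reduces the per-warp partials, so the final value is
  //    available in all warps without a separate broadcast pass.
  v = (lane < n_warps) ? warp_sums[lane] : 0.0f;
  for (int off = 16; off > 0; off /= 2) {
    v += __shfl_down_sync(0xffffffff, v, off);
  }
  return __shfl_sync(0xffffffff, v, 0);  // broadcast lane 0 to the whole warp
}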

View File

@@ -63,38 +63,6 @@ array prepare_sdpa_input(const array& x, Stream s) {
return x;
}
void malloc_with_same_layout(
cu::CommandEncoder& encoder,
array& o,
const array& q) {
if (q.flags().row_contiguous) {
o.set_data(cu::malloc_async(o.nbytes(), encoder));
return;
}
// fill_order = argsort(q.strides())
Shape fill_order(q.ndim());
std::iota(fill_order.begin(), fill_order.end(), 0);
std::stable_sort(
fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
return s1 < s2;
});
// Generate o_strides with fill_order
Strides o_strides(q.ndim());
int64_t stride = 1;
for (int i : fill_order) {
o_strides[i] = stride;
stride *= o.shape(i);
}
// o is a transposed contiguous array
o.set_data(
cu::malloc_async(o.nbytes(), encoder),
o.size(),
o_strides,
{true, false, false});
}
constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
@@ -107,8 +75,6 @@ struct SDPACacheKey {
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
std::array<int, QKV_NDIM> mask_shape;
std::array<int64_t, QKV_NDIM> mask_strides;
bool output_logsumexp;
};
@@ -118,7 +84,6 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
@@ -131,26 +96,20 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
{},
{},
output_logsumexp,
};
if (mask_arr) {
cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
}
return cache_key;
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
@@ -159,7 +118,6 @@ enum UIDS {
K,
V,
SCALE,
BIAS,
O,
STATS,
// Backward graph:
@@ -175,7 +133,6 @@ fe::graph::Graph build_sdpa_graph(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
const array& o,
const array& stats) {
@@ -207,19 +164,8 @@ fe::graph::Graph build_sdpa_graph(
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(output_logsumexp);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
}
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
@@ -246,7 +192,6 @@ fe::graph::Graph build_sdpa_backward_graph(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
const array& o,
const array& d_o,
const array& stats,
@@ -288,19 +233,7 @@ fe::graph::Graph build_sdpa_backward_graph(
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_attn_scale(scale);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
}
.set_causal_mask(do_causal);
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
@@ -353,6 +286,7 @@ bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
@@ -365,8 +299,19 @@ bool supports_sdpa_cudnn(
return false;
}
// Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
if (has_mask) {
// TODO: Support array masks.
if (!do_causal) {
return false;
}
// FIXME: Causal mask generates wrong results when L_Q != L_K.
if (q.shape(2) != k.shape(2)) {
return false;
}
}
// Only use cuDNN for prefilling and training.
if (q.shape(2) != k.shape(2)) {
return false;
}
@@ -388,33 +333,32 @@ void sdpa_cudnn(
array& o,
array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
malloc_with_same_layout(encoder, o, q);
// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder));
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
handle, q, k, v, do_causal, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -425,9 +369,6 @@ void sdpa_cudnn(
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
}
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
@@ -443,7 +384,6 @@ void sdpa_backward_cudnn(
const array& o,
const array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
const array& d_o,
array& d_q,
array& d_k,
@@ -452,9 +392,10 @@ void sdpa_backward_cudnn(
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
malloc_with_same_layout(encoder, d_q, q);
malloc_with_same_layout(encoder, d_k, k);
malloc_with_same_layout(encoder, d_v, v);
// TODO: Handle donation.
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
encoder.set_input_array(q);
encoder.set_input_array(k);
@@ -465,16 +406,13 @@ void sdpa_backward_cudnn(
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -490,9 +428,6 @@ void sdpa_backward_cudnn(
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
}
execute_graph(encoder, handle, graph, variant_pack);
}
@@ -534,11 +469,7 @@ bool ScaledDotProductAttention::use_fallback(
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, do_causal, s);
}
bool ScaledDotProductAttention::supports_bool_mask() {
return false;
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
}
void ScaledDotProductAttention::eval_gpu(
@@ -556,11 +487,6 @@ void ScaledDotProductAttention::eval_gpu(
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
@@ -569,17 +495,7 @@ void ScaledDotProductAttention::eval_gpu(
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(
q,
k,
v,
scale_,
out,
stats,
do_causal_,
mask_arr,
output_logsumexp_,
s);
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
}
}
@@ -599,21 +515,13 @@ void ScaledDotProductAttentionVJP::eval_gpu(
auto& s = stream();
assert(inputs.size() >= 6);
int primals_size = inputs.size() - 3;
bool has_arr_mask = primals_size > 3 + has_sinks_;
assert(inputs.size() == 6);
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[primals_size], s);
array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
array o = prepare_sdpa_input(inputs[3], s);
array stats = prepare_sdpa_input(inputs[4], s);
array d_o = prepare_sdpa_input(inputs[5], s);
assert(outputs.size() == 3);
auto& d_q = outputs[0];
@@ -621,7 +529,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
}
} // namespace fast
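
One detail worth calling out in the SDPA hunks: when a causal mask is requested and the query and key lengths differ, "causal" can be aligned to the top-left or to the bottom-right of the score matrix, and the two give different results; that is what the set_causal_mask vs set_causal_mask_bottom_right branch chooses between. A hypothetical helper (not MLX or cuDNN code) that states the two alignments explicitly:

// Returns whether query position i may attend to key position j.
// Top-left alignment anchors the diagonal at (0, 0); bottom-right alignment
// anchors it so the last query row attends up to the last key.
bool causal_allowed(int i, int j, int L_q, int L_k, bool bottom_right) {
  const int offset = bottom_right ? (L_k - L_q) : 0;
  return j <= i + offset;
}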

View File

@@ -5,7 +5,6 @@
#include "mlx/dtype_utils.h"
#include <fmt/format.h>
#include <vector>
namespace mlx::core {

View File

@@ -121,6 +121,14 @@ if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(MLX_ENABLE_NAX TRUE)
target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
else()
set(MLX_ENABLE_NAX FALSE)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(mlx

View File

@@ -265,19 +265,14 @@ Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
#ifdef MLX_ENABLE_NAX
inline bool is_nax_available() {
auto _check_nax = []() {
bool can_use_nax = false;
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
can_use_nax = true;
}
can_use_nax &=
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return can_use_nax;
};
static bool is_nax_available_ = _check_nax();
static bool is_nax_available_ =
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return is_nax_available_;
}
#endif // MLX_ENABLE_NAX
} // namespace mlx::core::metal
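
Both variants of is_nax_available() above rely on the same idiom: run the hardware/OS capability probe once and memoize the result in a function-local static, which C++ initializes thread-safely on first use. A generic sketch of that pattern with a stand-in probe:

// probe() is a stand-in for the real check (OS availability plus the Metal
// device's architecture generation); it runs exactly once.
static bool probe() {
  return true;  // hypothetical capability probe
}

inline bool feature_available() {
  static const bool available = probe();
  return available;
}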

View File

@@ -9,17 +9,13 @@ set(BASE_HEADERS
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS
-x
metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions)
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_ENABLE_NAX)
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -127,8 +123,8 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
if(MLX_ENABLE_NAX)
set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h

View File

@@ -16,77 +16,7 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CCC="xcrun -sdk macosx metal -x metal"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
declare -a HDRS_LIST=($HDRS)
declare -a HDRS_STACK=()
declare -a HDRS_SORTED=()
length=${#HDRS_LIST[@]}
HDRS_LIST+=(".")
for ((i=0; i<${length}; i+=2));
do
header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
str_this="${HDRS_LIST[$i]}"
str_next="${HDRS_LIST[$i + 2]}"
depth_this=${#str_this}
depth_next=${#str_next}
# If we have a dependency then we stack it
if [ $depth_next -gt $depth_this ]; then
HDRS_STACK=($header ${HDRS_STACK[@]})
# If we are done with this level
else
# We add the header to our list
HDRS_SORTED+=($header)
# Pop the stacked up dependencies
pop_len=$((depth_this - depth_next))
for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
do
HDRS_SORTED+=($popped_header)
done
HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
fi
done
HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
CONTENT=$(
echo "// Copyright © 2025 Apple Inc."
echo ""
echo "// Auto generated source for $INPUT_FILE"
echo ""
for header in "${HDRS_SORTED[@]}"
do
echo "///////////////////////////////////////////////////////////////////////////////"
echo "// Contents from \"${header}\""
echo "///////////////////////////////////////////////////////////////////////////////"
echo ""
echo "#line 1 \"${header}\""
grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
echo ""
done
echo "///////////////////////////////////////////////////////////////////////////////"
)
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {

View File

@@ -172,6 +172,8 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
// Regular steel matmul dispatch
///////////////////////////////////////////////////////////////////////////////
#ifdef MLX_ENABLE_NAX
template <bool CHECK_AB>
void steel_matmul_regular_axpby_nax(
const Stream& s,
@@ -208,11 +210,11 @@ void steel_matmul_regular_axpby_nax(
std::ostringstream kname;
// clang-format off
kname << "steel_gemm_fused_nax_"
kname << "steel_gemm_fused_nax_"
<< (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(out)
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(out)
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn; // clang-format on
@@ -327,6 +329,8 @@ void steel_matmul_regular_axpby_nax(
d.add_temporaries(std::move(copies), s.index);
}
#endif // MLX_ENABLE_NAX
template <bool CHECK_AB>
void steel_matmul_regular_axpby(
const Stream& s,
@@ -353,35 +357,41 @@ void steel_matmul_regular_axpby(
int64_t C_batch_stride /* = 0*/,
float alpha /* = 1.0f */,
float beta /* = 0.0f */) {
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return steel_matmul_regular_axpby_nax<CHECK_AB>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ c,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides batch_strides = */ batch_strides,
/* int64_t A_batch_stride = */ A_batch_stride,
/* int64_t B_batch_stride = */ B_batch_stride,
/* int64_t matrix_stride_out = */ matrix_stride_out,
/* int64_t C_batch_stride = */ C_batch_stride,
/* float alpha = */ alpha,
/* float beta = */ beta);
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return steel_matmul_regular_axpby_nax<CHECK_AB>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ c,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides batch_strides = */ batch_strides,
/* int64_t A_batch_stride = */ A_batch_stride,
/* int64_t B_batch_stride = */ B_batch_stride,
/* int64_t matrix_stride_out = */ matrix_stride_out,
/* int64_t C_batch_stride = */ C_batch_stride,
/* float alpha = */ alpha,
/* float beta = */ beta);
}
}
#endif // MLX_ENABLE_NAX
using namespace mlx::steel;
// Determine dispatch kernel
@@ -395,11 +405,11 @@ void steel_matmul_regular_axpby(
std::ostringstream kname;
// clang-format off
kname << "steel_gemm_fused_"
kname << "steel_gemm_fused_"
<< (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(out)
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(out)
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn; // clang-format on
@@ -564,14 +574,14 @@ void steel_gemm_splitk_axpby(
std::ostringstream kname;
// clang-format off
kname << "steel_gemm_splitk_"
kname << "steel_gemm_splitk_"
<< (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(C_split)
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(C_split)
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
<< "_wm" << wm << "_wn" << wn
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
<< "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on
// Encode and dispatch gemm kernel
@@ -905,10 +915,10 @@ void gemv_axbpy(
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
// clang-format off
kname << "_bm" << bm << "_bn" << bn
<< "_sm" << sm << "_sn" << sn
kname << "_bm" << bm << "_bn" << bn
<< "_sm" << sm << "_sn" << sn
<< "_tm" << tm << "_tn" << tn
<< "_nc" << !contiguous_kernel
<< "_nc" << !contiguous_kernel
<< "_axpby" << do_axpby; // clang-format on
// Encode and dispatch kernel
@@ -1756,6 +1766,8 @@ void gather_mm_rhs(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_mm_rhs_nax(
const array& a_,
const array& b_,
@@ -1899,6 +1911,8 @@ void gather_mm_rhs_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_mv(
const array& mat_,
const array& vec_,
@@ -2182,10 +2196,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
if (metal::is_nax_available() &&
(env::enable_tf32() || a.dtype() != float32)) {
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
#ifdef MLX_ENABLE_NAX
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() &&
!issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
}
}
#endif // MLX_ENABLE_NAX
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}

View File

@@ -451,6 +451,8 @@ void qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void qmm_nax(
const array& x,
const array& w,
@@ -651,6 +653,8 @@ void gather_qmm_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void qmm(
const array& x,
const array& w,
@@ -666,25 +670,31 @@ void qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
int wm = 2;
@@ -762,27 +772,33 @@ void gather_qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return gather_qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* const array& lhs_indices = */ lhs_indices,
/* const array& rhs_indices = */ rhs_indices,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return gather_qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* const array& lhs_indices = */ lhs_indices,
/* const array& rhs_indices = */ rhs_indices,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
int wm = 2;
@@ -959,6 +975,8 @@ void gather_qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_qmm_rhs_nax(
const array& x_,
const array& w_,
@@ -1090,6 +1108,8 @@ void gather_qmm_rhs_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_qmm_rhs(
const array& x_,
const array& w_,
@@ -1106,26 +1126,32 @@ void gather_qmm_rhs(
metal::Device& d,
const Stream& s,
const std::string mode) {
if (metal::is_nax_available() && transpose &&
(env::enable_tf32() || x_.dtype() != float32)) {
return gather_qmm_rhs_nax(
/* const array& x_ = */ x_,
/* const array& w_ = */ w_,
/* const array& scales_ = */ scales_,
/* const std::optional<array>& biases_ = */ biases_,
/* const array& indices_ = */ indices_,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string mode = */ mode);
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose &&
(env::enable_tf32() || x_.dtype() != float32)) {
return gather_qmm_rhs_nax(
/* const array& x_ = */ x_,
/* const array& w_ = */ w_,
/* const array& scales_ = */ scales_,
/* const std::optional<array>& biases_ = */ biases_,
/* const array& indices_ = */ indices_,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);

View File

@@ -13,6 +13,8 @@ namespace mlx::core::fast {
namespace {
#ifdef MLX_ENABLE_NAX
void sdpa_full_self_attention_nax(
const Stream& s,
metal::Device& d,
@@ -148,6 +150,8 @@ void sdpa_full_self_attention_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void sdpa_full_self_attention_metal(
const Stream& s,
metal::Device& d,
@@ -159,20 +163,24 @@ void sdpa_full_self_attention_metal(
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& q = */ q,
/* const array& k = */ k,
/* const array& v = */ v,
/* const float scale = */ scale,
/* array& o = */ o,
/* bool do_causal_ = */ do_causal_,
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& q = */ q,
/* const array& k = */ k,
/* const array& v = */ v,
/* const float scale = */ scale,
/* array& o = */ o,
/* bool do_causal_ = */ do_causal_,
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
}
#endif // MLX_ENABLE_NAX
using namespace mlx::steel;
@@ -569,10 +577,6 @@ bool ScaledDotProductAttention::use_fallback(
return !(supports_sdpa_full || supports_sdpa_vector);
}
bool ScaledDotProductAttention::supports_bool_mask() {
return true;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {

View File

@@ -36,10 +36,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
return true;
}
bool fast::ScaledDotProductAttention::supports_bool_mask() {
return false;
}
bool fast::ScaledDotProductAttentionVJP::use_fallback(
const array& q,
Stream s) {

View File

@@ -4,6 +4,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
if(MLX_BUILD_CPU AND NOT WIN32)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ibv)

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ibv/ibv.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
@@ -102,7 +103,8 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available();
return mpi::is_available() || ring::is_available() || nccl::is_available() ||
ibv::is_available();
}
int Group::rank() const {
@@ -135,6 +137,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "ibv") {
group = ibv::init(strict);
} else if (bk == "any") {
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
@@ -148,13 +152,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(false);
bk_ = "mpi";
}
if (group == nullptr) {
group = ibv::init(false);
bk_ = "ibv";
}
if (group == nullptr && strict) {
throw std::runtime_error("[distributed] Couldn't initialize any backend");
}
} else {
std::ostringstream msg;
msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
<< "and 'ring' but '" << bk << "' was provided.";
msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
<< "'ibv' and 'ring' but '" << bk << "' was provided.";
throw std::invalid_argument(msg.str());
}
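
A minimal sketch of selecting the new backend explicitly through the `init` entry point shown in this hunk; passing strict = true makes a missing ibv setup throw instead of silently falling back (the surrounding main() is illustrative):

#include "mlx/distributed/distributed.h"

int main() {
  // Ask for the ibv backend by name; with strict = true this throws if the
  // ibv group cannot be initialized instead of returning a singleton group.
  auto group = mlx::core::distributed::init(/* strict = */ true, /* bk = */ "ibv");
  // rank() identifies this process within the group (see Group::rank above).
  int rank = group.rank();
  (void)rank;
  return 0;
}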

View File

@@ -0,0 +1,8 @@
if(MLX_BUILD_CPU
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ibv.cpp)
target_link_libraries(mlx PRIVATE rdma)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ibv.cpp)
endif()

1122
mlx/distributed/ibv/ibv.cpp Normal file

File diff suppressed because it is too large

12
mlx/distributed/ibv/ibv.h Normal file
View File

@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::ibv {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::ibv

View File

@@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/ibv/ibv.h"
namespace mlx::core::distributed::ibv {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize ibv distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::ibv

View File

@@ -5,48 +5,3 @@
ncclResult_t ncclGetUniqueId(ncclUniqueId*) {
return ncclSuccess;
}
const char* ncclGetErrorString(ncclResult_t result) {
return nullptr;
}
ncclResult_t
ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
return ncclSuccess;
}
ncclResult_t ncclCommDestroy(ncclComm_t comm) {
return ncclSuccess;
}
ncclResult_t ncclAllGather(
const void* sendbuff,
void* recvbuff,
size_t sendcount,
ncclDataType_t datatype,
ncclComm_t comm,
cudaStream_t stream) {
return ncclSuccess;
}
ncclResult_t ncclAllReduce(
const void* sendbuff,
void* recvbuff,
size_t count,
ncclDataType_t datatype,
ncclRedOp_t op,
ncclComm_t comm,
cudaStream_t stream) {
return ncclSuccess;
}
ncclResult_t ncclReduceScatter(
const void* sendbuff,
void* recvbuff,
size_t recvcount,
ncclDataType_t datatype,
ncclRedOp_t op,
ncclComm_t comm,
cudaStream_t stream) {
return ncclSuccess;
}

View File

@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::distributed::detail {
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) const {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace mlx::core::distributed::detail
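
These functors fold a received chunk into a local buffer element by element; the ring backend below passes them to `all_reduce` via `SWITCH_TYPE`. A self-contained sketch of the same accumulation step (SumOp is repeated here so the snippet compiles on its own):

#include <cassert>
#include <cstddef>
#include <vector>

// Copy of SumOp from reduction_ops.h so this sketch stands alone.
template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output += *input;
      input++;
      output++;
    }
  }
};

int main() {
  std::vector<float> received = {1.f, 2.f, 3.f};
  std::vector<float> local = {10.f, 20.f, 30.f};
  // Accumulate the peer's segment into ours, as one step of a ring all-sum.
  SumOp<float>()(received.data(), local.data(), local.size());
  assert(local[0] == 11.f && local[2] == 33.f);
  return 0;
}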

View File

@@ -1,9 +1,6 @@
// Copyright © 2024 Apple Inc.
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
@@ -22,6 +19,8 @@
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/distributed/utils.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
@@ -94,6 +93,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
constexpr const char* RING_TAG = "[ring]";
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
@@ -296,55 +296,6 @@ class CommunicationThreads {
std::unordered_map<int, SocketThread> threads_;
};
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
/**
* Load all addresses from the json hostfile. The hostfile is a list of
* addresses in order of rank. For each rank there can be many addresses so
@@ -357,15 +308,15 @@ address_t parse_address(const std::string& ip_port) {
* ["ip3:5000", "ip3:5001"],
* ]
*/
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<address_t>> nodes;
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
std::vector<std::vector<detail::address_t>> nodes;
std::ifstream f(hostfile);
json hosts = json::parse(f);
for (auto& h : hosts) {
std::vector<address_t> host;
std::vector<detail::address_t> host;
for (auto& ips : h) {
host.push_back(parse_address(ips.get<std::string>()));
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
}
nodes.push_back(std::move(host));
}
@@ -377,73 +328,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
* Create a socket and accept one connection for each of the provided
* addresses.
*/
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
std::vector<int> accept_connections(
const std::vector<detail::address_t>& addresses) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
// Create the socket to wait for connections from the peers
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock, address.get(), address.len);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Wait for connections
success = listen(sock, 0);
if (success < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
int peer_socket = accept(sock, nullptr, nullptr);
if (peer_socket < 0) {
shutdown(sock, 2);
close(sock);
std::ostringstream msg;
msg << "[ring] Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Close the listening socket
shutdown(sock, 2);
close(sock);
sockets.push_back(peer_socket);
detail::TCPSocket socket(RING_TAG);
socket.listen(RING_TAG, address);
sockets.push_back(socket.accept(RING_TAG).detach());
}
return sockets;
@@ -454,93 +347,42 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
* provided addresses.
*/
std::vector<int> make_connections(
const std::vector<address_t>& addresses,
const std::vector<detail::address_t>& addresses,
bool verbose) {
std::vector<int> sockets;
int success;
for (auto& address : addresses) {
int sock;
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
// backoff. TODO: Do we need that?
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
if (attempt > 0) {
int wait = (1 << (attempt - 1)) * CONN_WAIT;
log_info(
verbose,
"Attempt",
attempt,
"wait",
wait,
"ms (error:",
errno,
")");
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
success = connect(sock, address.get(), address.len);
if (success == 0) {
break;
}
}
if (success < 0) {
std::ostringstream msg;
msg << "[ring] Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockets.push_back(sock);
sockets.push_back(detail::TCPSocket::connect(
RING_TAG,
address,
CONN_ATTEMPTS,
CONN_WAIT,
[verbose](int attempt, int wait) {
log_info(
verbose,
"Attempt",
attempt,
"waiting",
wait,
"ms (error:",
errno,
")");
})
.detach());
}
return sockets;
}
template <typename T>
struct SumOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output += *input;
input++;
output++;
}
}
};
template <typename T>
struct MaxOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::max(*output, *input);
input++;
output++;
}
}
};
template <typename T>
struct MinOp {
void operator()(const T* input, T* output, size_t N) {
while (N-- > 0) {
*output = std::min(*output, *input);
input++;
output++;
}
}
};
} // namespace
class RingGroup : public GroupImpl {
public:
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
RingGroup(
int rank,
std::vector<std::vector<detail::address_t>> nodes,
bool verbose)
: rank_(rank), verbose_(verbose), pool_(0) {
if (rank_ > 0 && rank_ >= nodes.size()) {
throw std::runtime_error(
@@ -633,17 +475,17 @@ class RingGroup : public GroupImpl {
void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
}
void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
}
void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE(
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {

203
mlx/distributed/utils.cpp Normal file
View File

@@ -0,0 +1,203 @@
// Copyright © 2025 Apple Inc.
#include <netdb.h>
#include <unistd.h>
#include <sstream>
#include <thread>
#include "mlx/distributed/utils.h"
namespace mlx::core::distributed::detail {
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port) {
struct addrinfo hints, *res;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
if (status != 0) {
std::ostringstream msg;
msg << "Can't parse address " << ip << ":" << port;
throw std::runtime_error(msg.str());
}
address_t result;
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
result.len = res->ai_addrlen;
freeaddrinfo(res);
return result;
}
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port) {
auto colon = ip_port.find(":");
if (colon == std::string::npos) {
std::ostringstream msg;
msg << "Can't parse address " << ip_port;
throw std::runtime_error(msg.str());
}
std::string ip(ip_port.begin(), ip_port.begin() + colon);
std::string port(ip_port.begin() + colon + 1, ip_port.end());
return parse_address(ip, port);
}
TCPSocket::TCPSocket(const char* tag) {
sock_ = socket(AF_INET, SOCK_STREAM, 0);
if (sock_ < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket::TCPSocket(TCPSocket&& s) {
sock_ = s.sock_;
s.sock_ = -1;
}
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
if (this != &s) {
sock_ = s.sock_;
s.sock_ = -1;
}
return *this;
}
TCPSocket::TCPSocket(int s) : sock_(s) {}
TCPSocket::~TCPSocket() {
if (sock_ > 0) {
shutdown(sock_, 2);
close(sock_);
}
}
int TCPSocket::detach() {
int s = sock_;
sock_ = -1;
return s;
}
void TCPSocket::listen(const char* tag, const address_t& addr) {
int success;
// Make sure we can launch immediately after shutdown by setting the
// reuseaddr option so that we don't get address already in use errors
int enable = 1;
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Bind the socket to the address and port
success = bind(sock_, addr.get(), addr.len);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't bind socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
// Prepare waiting for connections
success = ::listen(sock_, 0);
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't listen (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
}
TCPSocket TCPSocket::accept(const char* tag) {
int peer = ::accept(sock_, nullptr, nullptr);
if (peer < 0) {
std::ostringstream msg;
msg << tag << " Accept failed (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(peer);
}
void TCPSocket::send(const char* tag, const void* data, size_t len) {
while (len > 0) {
auto n = ::send(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Send failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<const char*>(data) + n;
}
}
void TCPSocket::recv(const char* tag, void* data, size_t len) {
while (len > 0) {
auto n = ::recv(sock_, data, len, 0);
if (n <= 0) {
std::ostringstream msg;
msg << tag << " Recv failed with errno=" << errno;
throw std::runtime_error(msg.str());
}
len -= n;
data = static_cast<char*>(data) + n;
}
}
TCPSocket TCPSocket::connect(
const char* tag,
const address_t& addr,
int num_retries,
int wait,
std::function<void(int, int)> cb) {
int sock, success;
// Attempt to connect `num_retries` times with exponential backoff.
for (int attempt = 0; attempt < num_retries; attempt++) {
// Create the socket
sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << tag << " Couldn't create socket to connect (error: " << errno
<< ")";
throw std::runtime_error(msg.str());
}
success = ::connect(sock, addr.get(), addr.len);
if (success == 0) {
break;
}
cb(attempt, wait);
if (wait > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
}
wait <<= 1;
}
if (success < 0) {
std::ostringstream msg;
msg << tag << " Couldn't connect (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
return TCPSocket(sock);
}
} // namespace mlx::core::distributed::detail

65
mlx/distributed/utils.h Normal file
View File

@@ -0,0 +1,65 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/socket.h>
namespace mlx::core::distributed::detail {
struct address_t {
sockaddr_storage addr;
socklen_t len;
const sockaddr* get() const {
return (struct sockaddr*)&addr;
}
};
/**
* Parse a sockaddr from an ip and port provided as strings.
*/
address_t parse_address(const std::string& ip, const std::string& port);
/**
* Parse a sockaddr provided as an <ip>:<port> string.
*/
address_t parse_address(const std::string& ip_port);
/**
* Small wrapper over a TCP socket to simplify initiating connections.
*/
class TCPSocket {
public:
TCPSocket(const char* tag);
TCPSocket(const TCPSocket&) = delete;
TCPSocket& operator=(const TCPSocket&) = delete;
TCPSocket(TCPSocket&& s);
TCPSocket& operator=(TCPSocket&&);
~TCPSocket();
void listen(const char* tag, const address_t& addr);
TCPSocket accept(const char* tag);
void send(const char* tag, const void* data, size_t len);
void recv(const char* tag, void* data, size_t len);
int detach();
operator int() const {
return sock_;
}
static TCPSocket connect(
const char* tag,
const address_t& addr,
int num_retries = 1,
int wait = 0,
std::function<void(int, int)> cb = nullptr);
private:
TCPSocket(int sock);
int sock_;
};
} // namespace mlx::core::distributed::detail
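
A rough usage sketch of this wrapper, assuming two cooperating processes on one host (the address, port, and "[demo]" tag are placeholders); it mirrors how accept_connections and make_connections in ring.cpp use it:

#include "mlx/distributed/utils.h"

using namespace mlx::core::distributed::detail;

// Rank 0: bind, listen, and accept a single peer, then echo 8 bytes back.
void serve() {
  TCPSocket listener("[demo]");
  listener.listen("[demo]", parse_address("127.0.0.1", "5000"));
  TCPSocket peer = listener.accept("[demo]");
  double value = 0.0;
  peer.recv("[demo]", &value, sizeof(value));
  peer.send("[demo]", &value, sizeof(value));
}

// Rank 1: connect with up to 5 attempts and exponential backoff from 100 ms.
void join() {
  TCPSocket peer = TCPSocket::connect(
      "[demo]",
      parse_address("127.0.0.1:5000"),
      /* num_retries = */ 5,
      /* wait = */ 100,
      [](int attempt, int wait) { /* optionally log the retry here */ });
  double value = 3.14;
  peer.send("[demo]", &value, sizeof(value));
  peer.recv("[demo]", &value, sizeof(value));
}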

View File

@@ -800,15 +800,6 @@ array scaled_dot_product_attention(
is_training,
output_logsumexp,
stream)) {
if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
// Convert bool mask to additive mask.
float inf = std::numeric_limits<float>::infinity();
array& mask = inputs[3];
mask = where(
mask,
full_like(mask, 0, final_type, s),
full_like(mask, -inf, final_type, s));
}
Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
auto primitive = std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
@@ -848,7 +839,7 @@ std::vector<array> ScaledDotProductAttention::vjp(
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (int i = 0; i < /* outputs size */ 3; ++i) {
for (int i = 0; i < primals.size(); ++i) {
shapes.push_back(primals[i].shape());
dtypes.push_back(primals[i].dtype());
}

View File

@@ -228,7 +228,6 @@ class ScaledDotProductAttention : public Custom {
bool is_training,
bool output_logsumexp,
Stream s);
static bool supports_bool_mask();
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {

View File

@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
std::vector<array>
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
check_float_or_complex(a.dtype(), "[linalg::svd]");
check_float(a.dtype(), "[linalg::svd]");
if (a.ndim() < 2) {
std::ostringstream msg;
@@ -268,12 +268,10 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
if (!compute_uv) {
return {array(
std::move(s_shape),
s_dtype,
a.dtype(),
std::make_shared<SVD>(to_stream(s), compute_uv),
{a})};
}
@@ -288,7 +286,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), s_dtype, a.dtype()},
{a.dtype(), a.dtype(), a.dtype()},
std::make_shared<SVD>(to_stream(s), compute_uv),
{a});
}
@@ -705,4 +703,4 @@ array solve_triangular(
return matmul(a_inv, b, s);
}
} // namespace mlx::core::linalg
} // namespace mlx::core::linalg

View File

@@ -135,11 +135,7 @@ class Scheduler {
~Scheduler() {
for (auto s : streams_) {
try {
synchronize(s);
} catch (const std::runtime_error&) {
// ignore errors if synch fails
}
synchronize(s);
}
for (auto t : threads_) {
if (t != nullptr) {

View File

@@ -407,10 +407,7 @@ class Module(dict):
instance).
Args:
apply_fn (Callable): The function to apply to the modules which
takes two parameters. The first parameter is the string path of
the module (e.g. ``"model.layers.0.linear"``). The second
parameter is the module object.
apply_fn (Callable): The function to apply to the modules.
Returns:
The module instance after updating submodules.

View File

@@ -1445,7 +1445,7 @@ void init_ops(nb::module_& m) {
"dtype"_a.none() = mx::float32,
"stream"_a = nb::none(),
nb::sig(
"def linspace(start: scalar, stop: scalar, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
"def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
@@ -4021,7 +4021,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array], Tuple[dict[str, array], dict[str, Any]]]"),
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
R"pbdoc(
Load array(s) from a binary file.
@@ -4037,12 +4037,11 @@ void init_ops(nb::module_& m) {
which support metadata. The metadata will be returned as an
additional dictionary. Default: ``False``.
Returns:
array, dict, or tuple:
array or dict:
A single array if loading from a ``.npy`` file or a dict
mapping names to arrays if loading from a ``.npz`` or
``.safetensors`` file. If ``return_metadata`` is ``True`` a
tuple ``(arrays, metadata)`` will be returned where the second
element is a dictionary containing the metadata.
``.safetensors`` file. If ``return_metadata`` is ``True`` an
additional dictionary of metadata will be returned.
Warning:

View File

@@ -1238,18 +1238,8 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
Returns:
tuple(list(array), list(array)): A tuple with the outputs of
``fun`` in the first position and the Jacobian-vector products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
list(array): A list of the Jacobian-vector products which
is the same in number, shape, and type of the inputs to ``fun``.
)pbdoc");
m.def(
"vjp",
@@ -1287,18 +1277,8 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the outputs of ``fun``.
Returns:
tuple(list(array), list(array)): A tuple with the outputs of
``fun`` in the first position and the vector-Jacobian products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
list(array): A list of the vector-Jacobian products which
is the same in number, shape, and type of the outputs of ``fun``.
)pbdoc");
m.def(
"value_and_grad",

View File

@@ -739,69 +739,37 @@ class TestSDPA(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
def test_sdpa_grad(self):
# High tolerance due to cuDNN SDPA kernel requiring tf32.
tolerance = {"rtol": 1e-2, "atol": 1e-2}
def test_vjp(slow, fast, primals):
cotan = mx.ones_like(primals[0])
o1, vjp1 = mx.vjp(slow, primals, [cotan])
o2, vjp2 = mx.vjp(fast, primals, [cotan])
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
def test_grad(slow, fast, args):
g1 = mx.grad(slow)(*args)
g2 = mx.grad(fast)(*args)
self.assertTrue(mx.allclose(g1, g2, **tolerance))
sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_mask_fast = lambda q, k, v, mask: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
).sum()
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
f4 = lambda q, k, v: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
).sum()
# High tolerance due to cuDNN SDPA kernel requiring tf32.
tolerance = {"rtol": 1e-2, "atol": 1e-2}
for N_q in (8, 32):
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
cotan = mx.ones_like(q)
o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
for mask in (mask_additive, mask_bool):
test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
for mask in (None, "causal"):
sdpa_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
g1 = mx.grad(f3)(q, k, v)
g2 = mx.grad(f4)(q, k, v)
loss_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
).sum()
test_grad(loss_slow, loss_fast, [q, k, v])
self.assertTrue(mx.allclose(g1, g2, **tolerance))
if __name__ == "__main__":

View File

@@ -168,42 +168,6 @@ class TestLinalg(mlx_tests.MLXTestCase):
)
)
# Test float64 - use CPU stream since float64 is not supported on GPU
with mx.stream(mx.cpu):
A_f64 = mx.array(
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
)
U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)
mx.eval(U_f64, S_f64, Vt_f64)
self.assertTrue(
mx.allclose(
U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,
A_f64,
rtol=1e-5,
atol=1e-7,
)
)
self.assertEqual(S_f64.dtype, mx.float64)
# Test complex64 - use CPU stream since complex64 is not supported on GPU
with mx.stream(mx.cpu):
A_c64 = mx.array(
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64
)
U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)
mx.eval(U_c64, S_c64, Vt_c64)
self.assertTrue(
mx.allclose(
U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,
A_c64,
rtol=1e-5,
atol=1e-7,
)
)
self.assertEqual(S_c64.dtype, mx.float32)
self.assertEqual(U_c64.dtype, mx.complex64)
self.assertEqual(Vt_c64.dtype, mx.complex64)
def test_inverse(self):
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
A_inv = mx.linalg.inv(A, stream=mx.cpu)
@@ -378,43 +342,6 @@ class TestLinalg(mlx_tests.MLXTestCase):
A_np = np.random.randn(3, n, n).astype(np.float32)
check_eigs_and_vecs(A_np)
# Test float64 - use CPU stream since float64 is not supported on GPU
with mx.stream(mx.cpu):
A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)
A_f64 = mx.array(A_np_f64, dtype=mx.float64)
eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)
mx.eval(eig_vals_f64, eig_vecs_f64)
self.assertTrue(
mx.allclose(
A_f64 @ eig_vecs_f64,
eig_vals_f64[..., None, :] * eig_vecs_f64,
rtol=1e-5,
atol=1e-5,
)
)
# Eigenvalues should be complex64 (output dtype)
self.assertEqual(eig_vals_f64.dtype, mx.complex64)
self.assertEqual(eig_vecs_f64.dtype, mx.complex64)
# Test complex64 input - use CPU stream since complex64 is not supported on GPU
with mx.stream(mx.cpu):
A_np_c64 = np.array(
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64
)
A_c64 = mx.array(A_np_c64, dtype=mx.complex64)
eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)
mx.eval(eig_vals_c64, eig_vecs_c64)
self.assertTrue(
mx.allclose(
A_c64 @ eig_vecs_c64,
eig_vals_c64[..., None, :] * eig_vecs_c64,
rtol=1e-5,
atol=1e-5,
)
)
self.assertEqual(eig_vals_c64.dtype, mx.complex64)
self.assertEqual(eig_vecs_c64.dtype, mx.complex64)
# Test error cases
with self.assertRaises(ValueError):
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array

View File

@@ -1443,22 +1443,23 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(a.tolist(), expected)
def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol, rtol):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x)
r_mlx = mlxop(y)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
x = np.random.rand(18, 28, 38)
for op in ["abs", "exp", "log", "square", "sqrt"]:
with self.subTest(op=op):
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
for dtype, atol, rtol in float_dtypes:
for dtype, atol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
def test_unary_ops_from_non_array(self):
unary_ops = [
@@ -1510,14 +1511,12 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
def test_trig_ops(self):
def test_ops(npop, mlxop, x, y, atol, rtol):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x)
r_mlx = mlxop(y)
mx.eval(r_mlx)
self.assertTrue(
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True))
x = np.random.rand(9, 12, 18)
xi = np.random.rand(9, 12, 18)
@@ -1527,34 +1526,34 @@ class TestOps(mlx_tests.MLXTestCase):
for op in all_fwd_ops:
with self.subTest(op=op):
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
for dtype, atol, rtol in float_dtypes:
for dtype, atol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
with self.subTest(op=op):
dtype = "complex64"
with self.subTest(dtype=dtype):
x_ = x + 1.0j * xi
x_ = x_.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
float_dtypes = [("complex64", 1e-5)]
for dtype, atol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x + 1.0j * xi
x_ = x_.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
with self.subTest(op="arc" + op):
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
op_inv = "arc" + op
for dtype, atol, rtol in float_dtypes:
for dtype, atol in float_dtypes:
with self.subTest(dtype=dtype):
np_op_fwd = getattr(np, op)
x_ = np_op_fwd(x).astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
)
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol)
# Test grads
np_vjp_funcs = {
@@ -1580,10 +1579,11 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32)
y_ = mx.array(x_)
op_ = op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
with self.subTest(op="arc" + op):
np_op_fwd = getattr(np, op)
@@ -1599,10 +1599,11 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32)
y_ = mx.array(x_)
op_ = "arc" + op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
def test_binary_ops(self):
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):