Compare commits

...

11 Commits

Author SHA1 Message Date
Andrey Portnoy
3e05cea9f8 Force cudaGraphExec reinstantiation when clusters are used (#2813)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 12:43:49 -08:00
CCYeh
5b0f047226 Fix mx.core.load type annotation (#2819) 2025-11-22 11:09:44 -08:00
Harsh Sutaria
618c87af8c Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 06:51:36 -08:00
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
Awni Hannun
0dbc7e5bee Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-21 13:28:15 -08:00
Awni Hannun
0d68efd461 patch bump for future version (#2804)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-20 09:26:20 -08:00
Awni Hannun
f9e1a14135 [CUDA] Partly fix random for large sizes (#2798) 2025-11-20 07:27:50 -08:00
Awni Hannun
d8e9ded928 Fix cuda allocator copy condition (#2800) 2025-11-20 07:06:55 -08:00
Awni Hannun
60939d010c Fix macos release target and linux arm release (#2802)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 21:37:50 -08:00
Awni Hannun
fdcd2923fd patch + fix docs build (#2799) 2025-11-19 16:16:26 -08:00
23 changed files with 760 additions and 456 deletions

View File

@@ -17,6 +17,8 @@ runs:
steps:
- name: Build Python package
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
@@ -25,6 +27,8 @@ runs:
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w

View File

@@ -13,7 +13,7 @@ permissions:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:

View File

@@ -23,7 +23,7 @@ jobs:
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
@@ -65,14 +65,14 @@ jobs:
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
@@ -208,7 +208,8 @@ jobs:
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist

View File

@@ -13,39 +13,31 @@ namespace mlx::core {
namespace {
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
{
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
@@ -62,58 +54,165 @@ void eig_impl(
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
a,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vec_ptr) {
if (vectors) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
auto a_ptr = a.data<T>();
auto val_ptr = values.data<complex64_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
complex64_t* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<complex64_t>();
}
encoder.dispatch([a_ptr,
val_ptr,
vec_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(a_ptr, val_ptr, vec_ptr);
a_ptr += N * N;
val_ptr += N;
if (vec_ptr) {
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
if (work.info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
<< work.info;
throw std::runtime_error(msg.str());
}
}
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -45,9 +45,7 @@
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
}
INSTANTIATE_LAPACK_COMPLEX(heevd)
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)

View File

@@ -8,6 +8,183 @@
namespace mlx::core {
template <typename T, class Enable = void>
struct SVDWork {};
template <typename T>
struct SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <>
struct SVDWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
const int lrwork =
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
int lwork_query = -1;
int work_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension.real();
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <typename T>
void svd_impl(
const array& a,
@@ -27,6 +204,8 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
using R = typename SVDWork<T>::R;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
@@ -42,7 +221,7 @@ void svd_impl(
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
R* s_ptr;
T* vt_ptr;
if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
encoder.set_output_array(s);
encoder.set_output_array(vt);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
@@ -68,96 +247,26 @@ void svd_impl(
encoder.set_output_array(s);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto jobz = (u_ptr) ? 'A' : 'N';
SVDWork<T> svd_work(N, M, K, jobz);
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
svd_work.run(
in_ptr + M * N * i,
s_ptr + K * i,
vt_ptr ? vt_ptr + N * N * i : nullptr,
u_ptr ? u_ptr + M * M * i : nullptr);
}
});
encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
case complex64:
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64.");
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
[this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
memory_limit_ = total * 0.9;
max_pool_size_ = memory_limit_;
int device_count = 0;
@@ -176,7 +176,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
// Copy to managed here if the buffer is not on the right device
if (buf->device != device) {
if (buf->device >= 0 && buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf};
@@ -219,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
scalar_pool_.free(buf);
} else {
if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
} else {
cudaFree(buf->data);
CHECK_CUDA_ERROR(cudaFree(buf->data));
}
delete buf;
}

View File

@@ -115,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
enc.empty_node_count_++;
CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
// Insert the concurrent -> empty node dependencies
for (auto& from : enc.concurrent_nodes_) {
enc.from_nodes_.push_back(from.node);
enc.to_nodes_.push_back(empty.node);
enc.graph_key_ += from.id;
enc.graph_key_ += from.node_type;
enc.graph_key_ += empty.id;
enc.graph_key_ += empty.node_type;
enc.graph_deps_key_ += from.id;
enc.graph_deps_key_ += "-";
enc.graph_deps_key_ += empty.id;
enc.graph_deps_key_ += "-";
}
// Insert the input -> concurrent node dependencies without updating output
@@ -141,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
if (node.node_type == 'G') {
graph_node_count_++;
}
node.id = std::to_string(node_count_++);
if (in_concurrent_) {
concurrent_nodes_.push_back(std::move(node));
@@ -155,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
}
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& node : nodes) {
graph_nodes_key_ += node.node_type;
graph_nodes_key_ += "-";
}
std::vector<GraphNode> deps;
{
// Dependencies must be added in the same order to produce a consistent
@@ -182,10 +182,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& to : nodes) {
from_nodes_.push_back(from.node);
to_nodes_.push_back(to.node);
graph_key_ += from.id;
graph_key_ += from.node_type;
graph_key_ += to.id;
graph_key_ += to.node_type;
graph_deps_key_ += from.id;
graph_deps_key_ += "-";
graph_deps_key_ += to.id;
graph_deps_key_ += "-";
}
}
}
@@ -309,13 +309,46 @@ void CommandEncoder::add_kernel_node(
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
insert_graph_dependencies(GraphNode{node, "K"});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
insert_graph_dependencies(GraphNode{node, "K"});
}
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// CUDA graphs do not get updated correctly if a kernel node getting updated
// has a different cluster shape than the node it's being updated with.
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return true;
}
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type != cudaGraphNodeTypeKernel) {
return false;
}
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
return true;
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -328,8 +361,11 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
int cluster_dim_x = 0;
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
}
bool CommandEncoder::needs_commit() {
@@ -354,14 +390,15 @@ void CommandEncoder::commit() {
from_nodes_.size()));
}
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(graph_node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
device_.make_current();
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (!is_graph_updatable_) {
CudaGraphExec graph_exec;
graph_exec.instantiate(graph_);
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
} else {
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
auto& graph_exec = graph_cache_[graph_key];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
@@ -381,17 +418,17 @@ void CommandEncoder::commit() {
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Reset state
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_key_.clear();
graph_deps_key_.clear();
graph_nodes_key_.clear();
node_map_.clear();
graph_ = CudaGraph(device_);
is_graph_updatable_ = true;
}
// Put completion handlers in a batch.

View File

@@ -106,8 +106,9 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// G = subgraph
char node_type;
// G* = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id;
};
@@ -119,12 +120,11 @@ class CommandEncoder {
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};
char empty_node_count_{0};
bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_;
std::vector<cudaGraphNode_t> to_nodes_;
std::string graph_key_;
std::string graph_nodes_key_;
std::string graph_deps_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
@@ -132,6 +132,7 @@ class CommandEncoder {
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
size_t bytes_in_graph_{0};
bool is_graph_updatable_{true};
int max_ops_per_graph_;
int max_mb_per_graph_;
};

View File

@@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2;
size_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder));
@@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2;
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
}
encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
int64_t total = num_keys * (half_size + odd);
uint32_t threads_y = 1;
while ((total / threads_y) >= UINT_MAX) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
uint32_t threads_x = cuda::ceil_div(total, threads_y);
dim3 grid_dims{
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto& stream = encoder.stream();
if (keys.flags().row_contiguous) {

View File

@@ -5,6 +5,7 @@
#include "mlx/dtype_utils.h"
#include <fmt/format.h>
#include <vector>
namespace mlx::core {

View File

@@ -121,14 +121,6 @@ if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(MLX_ENABLE_NAX TRUE)
target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
else()
set(MLX_ENABLE_NAX FALSE)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(mlx

View File

@@ -265,14 +265,19 @@ Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
#ifdef MLX_ENABLE_NAX
inline bool is_nax_available() {
static bool is_nax_available_ =
auto _check_nax = []() {
bool can_use_nax = false;
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
can_use_nax = true;
}
can_use_nax &=
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return can_use_nax;
};
static bool is_nax_available_ = _check_nax();
return is_nax_available_;
}
#endif // MLX_ENABLE_NAX
} // namespace mlx::core::metal

View File

@@ -9,13 +9,17 @@ set(BASE_HEADERS
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
set(METAL_FLAGS
-x
metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_ENABLE_NAX)
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -123,8 +127,8 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()
if(MLX_ENABLE_NAX)
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h

View File

@@ -172,8 +172,6 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
// Regular steel matmul dispatch
///////////////////////////////////////////////////////////////////////////////
#ifdef MLX_ENABLE_NAX
template <bool CHECK_AB>
void steel_matmul_regular_axpby_nax(
const Stream& s,
@@ -329,8 +327,6 @@ void steel_matmul_regular_axpby_nax(
d.add_temporaries(std::move(copies), s.index);
}
#endif // MLX_ENABLE_NAX
template <bool CHECK_AB>
void steel_matmul_regular_axpby(
const Stream& s,
@@ -357,9 +353,6 @@ void steel_matmul_regular_axpby(
int64_t C_batch_stride /* = 0*/,
float alpha /* = 1.0f */,
float beta /* = 0.0f */) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return steel_matmul_regular_axpby_nax<CHECK_AB>(
@@ -388,9 +381,6 @@ void steel_matmul_regular_axpby(
/* float alpha = */ alpha,
/* float beta = */ beta);
}
}
#endif // MLX_ENABLE_NAX
using namespace mlx::steel;
@@ -1766,8 +1756,6 @@ void gather_mm_rhs(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_mm_rhs_nax(
const array& a_,
const array& b_,
@@ -1911,8 +1899,6 @@ void gather_mm_rhs_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_mv(
const array& mat_,
const array& vec_,
@@ -2196,19 +2182,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() &&
!issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
}
}
#endif // MLX_ENABLE_NAX
gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}

View File

@@ -451,8 +451,6 @@ void qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void qmm_nax(
const array& x,
const array& w,
@@ -653,8 +651,6 @@ void gather_qmm_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void qmm(
const array& x,
const array& w,
@@ -670,9 +666,6 @@ void qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return qmm_nax(
@@ -691,9 +684,6 @@ void qmm(
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
@@ -772,9 +762,6 @@ void gather_qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return gather_qmm_nax(
@@ -795,9 +782,6 @@ void gather_qmm(
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
@@ -975,8 +959,6 @@ void gather_qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_qmm_rhs_nax(
const array& x_,
const array& w_,
@@ -1108,8 +1090,6 @@ void gather_qmm_rhs_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_qmm_rhs(
const array& x_,
const array& w_,
@@ -1126,9 +1106,6 @@ void gather_qmm_rhs(
metal::Device& d,
const Stream& s,
const std::string mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose &&
(env::enable_tf32() || x_.dtype() != float32)) {
return gather_qmm_rhs_nax(
@@ -1148,9 +1125,6 @@ void gather_qmm_rhs(
/* const Stream& s = */ s,
/* const std::string mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);

View File

@@ -13,8 +13,6 @@ namespace mlx::core::fast {
namespace {
#ifdef MLX_ENABLE_NAX
void sdpa_full_self_attention_nax(
const Stream& s,
metal::Device& d,
@@ -150,8 +148,6 @@ void sdpa_full_self_attention_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void sdpa_full_self_attention_metal(
const Stream& s,
metal::Device& d,
@@ -163,8 +159,6 @@ void sdpa_full_self_attention_metal(
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
@@ -179,8 +173,6 @@ void sdpa_full_self_attention_metal(
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
}
#endif // MLX_ENABLE_NAX
using namespace mlx::steel;

View File

@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
std::vector<array>
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
check_float(a.dtype(), "[linalg::svd]");
check_float_or_complex(a.dtype(), "[linalg::svd]");
if (a.ndim() < 2) {
std::ostringstream msg;
@@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
if (!compute_uv) {
return {array(
std::move(s_shape),
a.dtype(),
s_dtype,
std::make_shared<SVD>(to_stream(s), compute_uv),
{a})};
}
@@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
{a.dtype(), s_dtype, a.dtype()},
std::make_shared<SVD>(to_stream(s), compute_uv),
{a});
}

View File

@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 30
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -4021,7 +4021,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
"def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array], Tuple[dict[str, array], dict[str, Any]]]"),
R"pbdoc(
Load array(s) from a binary file.
@@ -4037,11 +4037,12 @@ void init_ops(nb::module_& m) {
which support matadata. The metadata will be returned as an
additional dictionary. Default: ``False``.
Returns:
array or dict:
array, dict, or tuple:
A single array if loading from a ``.npy`` file or a dict
mapping names to arrays if loading from a ``.npz`` or
``.safetensors`` file. If ``return_metadata`` is ``True`` an
additional dictionary of metadata will be returned.
``.safetensors`` file. If ``return_metadata`` is ``True`` a
tuple ``(arrays, metadata)`` will be returned where the second
element is a dictionary containing the metadata.
Warning:

View File

@@ -168,6 +168,42 @@ class TestLinalg(mlx_tests.MLXTestCase):
)
)
# Test float64 - use CPU stream since float64 is not supported on GPU
with mx.stream(mx.cpu):
A_f64 = mx.array(
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
)
U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)
mx.eval(U_f64, S_f64, Vt_f64)
self.assertTrue(
mx.allclose(
U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,
A_f64,
rtol=1e-5,
atol=1e-7,
)
)
self.assertEqual(S_f64.dtype, mx.float64)
# Test complex64 - use CPU stream since complex64 is not supported on GPU
with mx.stream(mx.cpu):
A_c64 = mx.array(
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64
)
U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)
mx.eval(U_c64, S_c64, Vt_c64)
self.assertTrue(
mx.allclose(
U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,
A_c64,
rtol=1e-5,
atol=1e-7,
)
)
self.assertEqual(S_c64.dtype, mx.float32)
self.assertEqual(U_c64.dtype, mx.complex64)
self.assertEqual(Vt_c64.dtype, mx.complex64)
def test_inverse(self):
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
A_inv = mx.linalg.inv(A, stream=mx.cpu)
@@ -342,6 +378,43 @@ class TestLinalg(mlx_tests.MLXTestCase):
A_np = np.random.randn(3, n, n).astype(np.float32)
check_eigs_and_vecs(A_np)
# Test float64 - use CPU stream since float64 is not supported on GPU
with mx.stream(mx.cpu):
A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)
A_f64 = mx.array(A_np_f64, dtype=mx.float64)
eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)
mx.eval(eig_vals_f64, eig_vecs_f64)
self.assertTrue(
mx.allclose(
A_f64 @ eig_vecs_f64,
eig_vals_f64[..., None, :] * eig_vecs_f64,
rtol=1e-5,
atol=1e-5,
)
)
# Eigenvalues should be complex64 (output dtype)
self.assertEqual(eig_vals_f64.dtype, mx.complex64)
self.assertEqual(eig_vecs_f64.dtype, mx.complex64)
# Test complex64 input - use CPU stream since complex64 is not supported on GPU
with mx.stream(mx.cpu):
A_np_c64 = np.array(
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64
)
A_c64 = mx.array(A_np_c64, dtype=mx.complex64)
eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)
mx.eval(eig_vals_c64, eig_vecs_c64)
self.assertTrue(
mx.allclose(
A_c64 @ eig_vecs_c64,
eig_vals_c64[..., None, :] * eig_vecs_c64,
rtol=1e-5,
atol=1e-5,
)
)
self.assertEqual(eig_vals_c64.dtype, mx.complex64)
self.assertEqual(eig_vecs_c64.dtype, mx.complex64)
# Test error cases
with self.assertRaises(ValueError):
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array

View File

@@ -1443,23 +1443,22 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(a.tolist(), expected)
def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol):
def test_ops(npop, mlxop, x, y, atol, rtol):
r_np = npop(x)
r_mlx = mlxop(y)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
x = np.random.rand(18, 28, 38)
for op in ["abs", "exp", "log", "square", "sqrt"]:
with self.subTest(op=op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
for dtype, atol in float_dtypes:
for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
def test_unary_ops_from_non_array(self):
unary_ops = [
@@ -1511,12 +1510,14 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
def test_trig_ops(self):
def test_ops(npop, mlxop, x, y, atol):
def test_ops(npop, mlxop, x, y, atol, rtol):
r_np = npop(x)
r_mlx = mlxop(y)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True))
self.assertTrue(
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
)
x = np.random.rand(9, 12, 18)
xi = np.random.rand(9, 12, 18)
@@ -1526,34 +1527,34 @@ class TestOps(mlx_tests.MLXTestCase):
for op in all_fwd_ops:
with self.subTest(op=op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
for dtype, atol in float_dtypes:
for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
with self.subTest(op=op):
float_dtypes = [("complex64", 1e-5)]
for dtype, atol in float_dtypes:
dtype = "complex64"
with self.subTest(dtype=dtype):
x_ = x + 1.0j * xi
x_ = x_.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
with self.subTest(op="arc" + op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
op_inv = "arc" + op
for dtype, atol in float_dtypes:
for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype):
np_op_fwd = getattr(np, op)
x_ = np_op_fwd(x).astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol)
test_ops(
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
)
# Test grads
np_vjp_funcs = {
@@ -1579,11 +1580,10 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32)
y_ = mx.array(x_)
op_ = op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
with self.subTest(op="arc" + op):
np_op_fwd = getattr(np, op)
@@ -1599,11 +1599,10 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32)
y_ = mx.array(x_)
op_ = "arc" + op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
def test_binary_ops(self):
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):

View File

@@ -24,8 +24,8 @@ def get_version():
if "#define MLX_VERSION_PATCH" in l:
patch = l.split()[-1]
version = f"{major}.{minor}.{patch}"
pypi_release = os.environ.get("PYPI_RELEASE", False)
dev_release = os.environ.get("DEV_RELEASE", False)
pypi_release = int(os.environ.get("PYPI_RELEASE", 0))
dev_release = int(os.environ.get("DEV_RELEASE", 0))
if not pypi_release or dev_release:
today = datetime.date.today()
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"