mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
9 Commits
v0.30.0
...
618c87af8c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
618c87af8c | ||
|
|
d5f61a93fa | ||
|
|
4a09264236 | ||
|
|
0dbc7e5bee | ||
|
|
0d68efd461 | ||
|
|
f9e1a14135 | ||
|
|
d8e9ded928 | ||
|
|
60939d010c | ||
|
|
fdcd2923fd |
@@ -17,6 +17,8 @@ runs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Build Python package
|
- name: Build Python package
|
||||||
shell: bash -l {0}
|
shell: bash -l {0}
|
||||||
|
env:
|
||||||
|
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||||
run: |
|
run: |
|
||||||
pip install build
|
pip install build
|
||||||
python setup.py clean --all
|
python setup.py clean --all
|
||||||
@@ -25,6 +27,8 @@ runs:
|
|||||||
- name: Build backend package
|
- name: Build backend package
|
||||||
if: ${{ inputs.build-backend }}
|
if: ${{ inputs.build-backend }}
|
||||||
shell: bash -l {0}
|
shell: bash -l {0}
|
||||||
|
env:
|
||||||
|
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||||
run: |
|
run: |
|
||||||
python setup.py clean --all
|
python setup.py clean --all
|
||||||
MLX_BUILD_STAGE=2 python -m build -w
|
MLX_BUILD_STAGE=2 python -m build -w
|
||||||
|
|||||||
2
.github/workflows/pull_request.yml
vendored
2
.github/workflows/pull_request.yml
vendored
@@ -13,7 +13,7 @@ permissions:
|
|||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check_lint:
|
check_lint:
|
||||||
|
|||||||
9
.github/workflows/release.yml
vendored
9
.github/workflows/release.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
build_documentation:
|
build_documentation:
|
||||||
if: github.repository == 'ml-explore/mlx'
|
if: github.repository == 'ml-explore/mlx'
|
||||||
runs-on: [self-hosted, macos]
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v5
|
||||||
- uses: ./.github/actions/build-docs
|
- uses: ./.github/actions/build-docs
|
||||||
@@ -65,14 +65,14 @@ jobs:
|
|||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
overwrite: true
|
overwrite: true
|
||||||
name: linux-wheels-${{ matrix.python_version }}
|
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
|
||||||
path: wheelhouse/mlx-*.whl
|
path: wheelhouse/mlx-*.whl
|
||||||
- name: Upload CPU artifacts
|
- name: Upload CPU artifacts
|
||||||
if: matrix.python_version == '3.10'
|
if: matrix.python_version == '3.10'
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
overwrite: true
|
overwrite: true
|
||||||
name: mlx-cpu
|
name: mlx-cpu-${{ matrix.arch }}
|
||||||
path: wheelhouse/mlx_cpu-*.whl
|
path: wheelhouse/mlx_cpu-*.whl
|
||||||
|
|
||||||
build_mac_release:
|
build_mac_release:
|
||||||
@@ -208,7 +208,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/download-artifact@v6
|
- uses: actions/download-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: mlx-cpu
|
pattern: mlx-cpu-*
|
||||||
|
merge-multiple: true
|
||||||
path: dist
|
path: dist
|
||||||
- name: Display structure of downloaded files
|
- name: Display structure of downloaded files
|
||||||
run: ls -R dist
|
run: ls -R dist
|
||||||
|
|||||||
@@ -12,6 +12,167 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
complex64_t to_complex(T r, T i) {
|
||||||
|
return {static_cast<float>(r), static_cast<float>(i)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Primary template: intentionally empty. The usable implementations are the
// specializations selected through the second (SFINAE) parameter or by exact
// element type.
template <typename T, typename Enable = void>
struct EigWork {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct EigWork<
|
||||||
|
T,
|
||||||
|
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||||
|
using O = complex64_t;
|
||||||
|
|
||||||
|
char jobl;
|
||||||
|
char jobr;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
|
||||||
|
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
|
||||||
|
T work;
|
||||||
|
int n_vecs_l = compute_eigenvectors ? N_ : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work);
|
||||||
|
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
|
||||||
|
if (compute_eigenvectors) {
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
|
||||||
|
}
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, O* values, O* vectors) {
|
||||||
|
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
|
||||||
|
T* vec_tmp = nullptr;
|
||||||
|
if (vectors) {
|
||||||
|
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
|
||||||
|
}
|
||||||
|
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
|
||||||
|
|
||||||
|
int n_vecs_l = vectors ? N : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
a,
|
||||||
|
&N,
|
||||||
|
eig_tmp,
|
||||||
|
eig_tmp + N,
|
||||||
|
vectors ? vec_tmp : nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
work,
|
||||||
|
&lwork,
|
||||||
|
&info);
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vectors) {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
if (values[i].imag() != 0) {
|
||||||
|
for (int j = 0; j < N; ++j) {
|
||||||
|
vectors[i * N + j] =
|
||||||
|
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
|
||||||
|
vectors[(i + 1) * N + j] =
|
||||||
|
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
} else {
|
||||||
|
for (int j = 0; j < N; ++j) {
|
||||||
|
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct EigWork<std::complex<float>> {
|
||||||
|
using T = std::complex<float>;
|
||||||
|
using R = float;
|
||||||
|
using O = T;
|
||||||
|
|
||||||
|
char jobl;
|
||||||
|
char jobr;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int lrwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
|
||||||
|
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
|
||||||
|
T work;
|
||||||
|
R rwork;
|
||||||
|
int n_vecs_l = compute_eigenvectors ? N_ : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&rwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work.real());
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, T* values, T* vectors) {
|
||||||
|
int n_vecs_l = vectors ? N : 1;
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
a,
|
||||||
|
&N,
|
||||||
|
values,
|
||||||
|
vectors,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
&info);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void eig_impl(
|
void eig_impl(
|
||||||
array& a,
|
array& a,
|
||||||
@@ -19,101 +180,39 @@ void eig_impl(
|
|||||||
array& values,
|
array& values,
|
||||||
bool compute_eigenvectors,
|
bool compute_eigenvectors,
|
||||||
Stream stream) {
|
Stream stream) {
|
||||||
using OT = std::complex<T>;
|
|
||||||
auto a_ptr = a.data<T>();
|
auto a_ptr = a.data<T>();
|
||||||
auto eig_ptr = values.data<OT>();
|
auto val_ptr = values.data<complex64_t>();
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
encoder.set_output_array(values);
|
encoder.set_output_array(values);
|
||||||
OT* vec_ptr = nullptr;
|
complex64_t* vec_ptr = nullptr;
|
||||||
if (compute_eigenvectors) {
|
if (compute_eigenvectors) {
|
||||||
encoder.set_output_array(vectors);
|
encoder.set_output_array(vectors);
|
||||||
vec_ptr = vectors.data<OT>();
|
vec_ptr = vectors.data<complex64_t>();
|
||||||
}
|
}
|
||||||
encoder.dispatch([a_ptr,
|
encoder.dispatch([a_ptr,
|
||||||
|
val_ptr,
|
||||||
vec_ptr,
|
vec_ptr,
|
||||||
eig_ptr,
|
|
||||||
compute_eigenvectors,
|
compute_eigenvectors,
|
||||||
N = vectors.shape(-1),
|
N = vectors.shape(-1),
|
||||||
size = vectors.size()]() mutable {
|
size = vectors.size()]() mutable {
|
||||||
// Work query
|
|
||||||
char jobr = 'N';
|
char jobr = 'N';
|
||||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||||
int n_vecs_r = 1;
|
|
||||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
|
||||||
int lwork = -1;
|
|
||||||
int info;
|
|
||||||
{
|
|
||||||
T work;
|
|
||||||
geev<T>(
|
|
||||||
&jobl,
|
|
||||||
&jobr,
|
|
||||||
&N,
|
|
||||||
nullptr,
|
|
||||||
&N,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
&n_vecs_l,
|
|
||||||
nullptr,
|
|
||||||
&n_vecs_r,
|
|
||||||
&work,
|
|
||||||
&lwork,
|
|
||||||
&info);
|
|
||||||
lwork = static_cast<int>(work);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
|
||||||
auto vec_tmp_data =
|
|
||||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
|
||||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
|
||||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
|
||||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
|
||||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||||
geev<T>(
|
work.run(a_ptr, val_ptr, vec_ptr);
|
||||||
&jobl,
|
a_ptr += N * N;
|
||||||
&jobr,
|
val_ptr += N;
|
||||||
&N,
|
|
||||||
a_ptr,
|
|
||||||
&N,
|
|
||||||
eig_tmp,
|
|
||||||
eig_tmp + N,
|
|
||||||
vec_tmp,
|
|
||||||
&n_vecs_l,
|
|
||||||
nullptr,
|
|
||||||
&n_vecs_r,
|
|
||||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
|
||||||
&lwork,
|
|
||||||
&info);
|
|
||||||
for (int i = 0; i < N; ++i) {
|
|
||||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
|
||||||
}
|
|
||||||
if (vec_ptr) {
|
if (vec_ptr) {
|
||||||
for (int i = 0; i < N; ++i) {
|
|
||||||
if (eig_ptr[i].imag() != 0) {
|
|
||||||
// This vector and the next are a pair
|
|
||||||
for (int j = 0; j < N; ++j) {
|
|
||||||
vec_ptr[i * N + j] = {
|
|
||||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
|
||||||
vec_ptr[(i + 1) * N + j] = {
|
|
||||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
|
||||||
}
|
|
||||||
i += 1;
|
|
||||||
} else {
|
|
||||||
for (int j = 0; j < N; ++j) {
|
|
||||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
vec_ptr += N * N;
|
vec_ptr += N * N;
|
||||||
}
|
}
|
||||||
a_ptr += N * N;
|
if (work.info != 0) {
|
||||||
eig_ptr += N;
|
|
||||||
if (info != 0) {
|
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||||
<< info;
|
<< work.info;
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
|
|||||||
case float32:
|
case float32:
|
||||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
break;
|
break;
|
||||||
|
case float64:
|
||||||
|
eig_impl<double>(
|
||||||
|
a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
eig_impl<std::complex<float>>(
|
||||||
|
a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
throw std::runtime_error(
|
||||||
|
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,9 +45,7 @@
|
|||||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||||
INSTANTIATE_LAPACK_REAL(syevd)
|
INSTANTIATE_LAPACK_REAL(syevd)
|
||||||
INSTANTIATE_LAPACK_REAL(geev)
|
|
||||||
INSTANTIATE_LAPACK_REAL(potrf)
|
INSTANTIATE_LAPACK_REAL(potrf)
|
||||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
|
||||||
INSTANTIATE_LAPACK_REAL(getrf)
|
INSTANTIATE_LAPACK_REAL(getrf)
|
||||||
INSTANTIATE_LAPACK_REAL(getri)
|
INSTANTIATE_LAPACK_REAL(getri)
|
||||||
INSTANTIATE_LAPACK_REAL(trtri)
|
INSTANTIATE_LAPACK_REAL(trtri)
|
||||||
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
|
|||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||||
|
|
||||||
|
#define INSTANTIATE_LAPACK_ALL(FUNC) \
|
||||||
|
template <typename T, typename... Args> \
|
||||||
|
void FUNC(Args... args) { \
|
||||||
|
if constexpr (std::is_same_v<T, float>) { \
|
||||||
|
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, double>) { \
|
||||||
|
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||||
|
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||||
|
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_LAPACK_ALL(geev)
|
||||||
|
INSTANTIATE_LAPACK_ALL(gesdd)
|
||||||
|
|||||||
@@ -8,6 +8,183 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Primary template: intentionally empty. The usable implementations are the
// specializations selected through the second (SFINAE) parameter or by exact
// element type.
template <typename T, typename Enable = void>
struct SVDWork {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SVDWork<
|
||||||
|
T,
|
||||||
|
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||||
|
using R = T;
|
||||||
|
|
||||||
|
int N;
|
||||||
|
int M;
|
||||||
|
int K;
|
||||||
|
int lda;
|
||||||
|
int ldu;
|
||||||
|
int ldvt;
|
||||||
|
char jobz;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
int lwork;
|
||||||
|
|
||||||
|
SVDWork(int N, int M, int K, char jobz)
|
||||||
|
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
|
||||||
|
T workspace_dimension = 0;
|
||||||
|
|
||||||
|
// Will contain the indices of eigenvectors that failed to converge (not
|
||||||
|
// used here but required by lapack).
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
|
||||||
|
|
||||||
|
int lwork_query = -1;
|
||||||
|
int info;
|
||||||
|
|
||||||
|
// Compute workspace size.
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ nullptr,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ nullptr,
|
||||||
|
/* u = */ nullptr,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
/* vt = */ nullptr,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ &workspace_dimension,
|
||||||
|
/* lwork = */ &lwork_query,
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
lwork = workspace_dimension;
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, R* s, T* u, T* vt) {
|
||||||
|
int info;
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ a,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ s,
|
||||||
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
|
/* u = */ u,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||||
|
/* vt = */ vt,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
/* lwork = */ &lwork,
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct SVDWork<std::complex<float>> {
|
||||||
|
using T = std::complex<float>;
|
||||||
|
using R = float;
|
||||||
|
|
||||||
|
int N;
|
||||||
|
int M;
|
||||||
|
int K;
|
||||||
|
int lda;
|
||||||
|
int ldu;
|
||||||
|
int ldvt;
|
||||||
|
char jobz;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
int lwork;
|
||||||
|
|
||||||
|
SVDWork(int N, int M, int K, char jobz)
|
||||||
|
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
|
||||||
|
T workspace_dimension = 0;
|
||||||
|
|
||||||
|
// Will contain the indices of eigenvectors that failed to converge (not
|
||||||
|
// used here but required by lapack).
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
|
||||||
|
|
||||||
|
const int lrwork =
|
||||||
|
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
|
||||||
|
|
||||||
|
int lwork_query = -1;
|
||||||
|
int work_query = -1;
|
||||||
|
int info;
|
||||||
|
|
||||||
|
// Compute workspace size.
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ nullptr,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ nullptr,
|
||||||
|
/* u = */ nullptr,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
/* vt = */ nullptr,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ &workspace_dimension,
|
||||||
|
/* lwork = */ &lwork_query,
|
||||||
|
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
lwork = workspace_dimension.real();
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* a, R* s, T* u, T* vt) {
|
||||||
|
int info;
|
||||||
|
gesdd<T>(
|
||||||
|
/* jobz = */ &jobz,
|
||||||
|
// M and N are swapped since lapack expects column-major.
|
||||||
|
/* m = */ &N,
|
||||||
|
/* n = */ &M,
|
||||||
|
/* a = */ a,
|
||||||
|
/* lda = */ &lda,
|
||||||
|
/* s = */ s,
|
||||||
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
|
/* u = */ u,
|
||||||
|
/* ldu = */ &ldu,
|
||||||
|
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||||
|
/* vt = */ vt,
|
||||||
|
/* ldvt = */ &ldvt,
|
||||||
|
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
|
||||||
|
/* lwork = */ &lwork,
|
||||||
|
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void svd_impl(
|
void svd_impl(
|
||||||
const array& a,
|
const array& a,
|
||||||
@@ -27,6 +204,8 @@ void svd_impl(
|
|||||||
const int N = a.shape(-1);
|
const int N = a.shape(-1);
|
||||||
const int K = std::min(M, N);
|
const int K = std::min(M, N);
|
||||||
|
|
||||||
|
using R = typename SVDWork<T>::R;
|
||||||
|
|
||||||
size_t num_matrices = a.size() / (M * N);
|
size_t num_matrices = a.size() / (M * N);
|
||||||
|
|
||||||
// lapack clobbers the input, so we have to make a copy.
|
// lapack clobbers the input, so we have to make a copy.
|
||||||
@@ -42,7 +221,7 @@ void svd_impl(
|
|||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
auto in_ptr = in.data<T>();
|
auto in_ptr = in.data<T>();
|
||||||
T* u_ptr;
|
T* u_ptr;
|
||||||
T* s_ptr;
|
R* s_ptr;
|
||||||
T* vt_ptr;
|
T* vt_ptr;
|
||||||
|
|
||||||
if (compute_uv) {
|
if (compute_uv) {
|
||||||
@@ -58,7 +237,7 @@ void svd_impl(
|
|||||||
encoder.set_output_array(s);
|
encoder.set_output_array(s);
|
||||||
encoder.set_output_array(vt);
|
encoder.set_output_array(vt);
|
||||||
|
|
||||||
s_ptr = s.data<T>();
|
s_ptr = s.data<R>();
|
||||||
u_ptr = u.data<T>();
|
u_ptr = u.data<T>();
|
||||||
vt_ptr = vt.data<T>();
|
vt_ptr = vt.data<T>();
|
||||||
} else {
|
} else {
|
||||||
@@ -68,96 +247,26 @@ void svd_impl(
|
|||||||
|
|
||||||
encoder.set_output_array(s);
|
encoder.set_output_array(s);
|
||||||
|
|
||||||
s_ptr = s.data<T>();
|
s_ptr = s.data<R>();
|
||||||
u_ptr = nullptr;
|
u_ptr = nullptr;
|
||||||
vt_ptr = nullptr;
|
vt_ptr = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
|
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
|
||||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
auto jobz = (u_ptr) ? 'A' : 'N';
|
||||||
const int lda = N;
|
SVDWork<T> svd_work(N, M, K, jobz);
|
||||||
// U of shape M x M. (N x N in lapack).
|
|
||||||
const int ldu = N;
|
|
||||||
// Vᵀ of shape N x N. (M x M in lapack).
|
|
||||||
const int ldvt = M;
|
|
||||||
|
|
||||||
auto jobz = (u_ptr) ? "A" : "N";
|
|
||||||
|
|
||||||
T workspace_dimension = 0;
|
|
||||||
|
|
||||||
// Will contain the indices of eigenvectors that failed to converge (not
|
|
||||||
// used here but required by lapack).
|
|
||||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
|
||||||
|
|
||||||
static const int lwork_query = -1;
|
|
||||||
|
|
||||||
int info;
|
|
||||||
|
|
||||||
// Compute workspace size.
|
|
||||||
gesdd<T>(
|
|
||||||
/* jobz = */ jobz,
|
|
||||||
// M and N are swapped since lapack expects column-major.
|
|
||||||
/* m = */ &N,
|
|
||||||
/* n = */ &M,
|
|
||||||
/* a = */ nullptr,
|
|
||||||
/* lda = */ &lda,
|
|
||||||
/* s = */ nullptr,
|
|
||||||
/* u = */ nullptr,
|
|
||||||
/* ldu = */ &ldu,
|
|
||||||
/* vt = */ nullptr,
|
|
||||||
/* ldvt = */ &ldvt,
|
|
||||||
/* work = */ &workspace_dimension,
|
|
||||||
/* lwork = */ &lwork_query,
|
|
||||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
|
||||||
/* info = */ &info);
|
|
||||||
|
|
||||||
if (info != 0) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
|
||||||
throw std::runtime_error(ss.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
const int lwork = workspace_dimension;
|
|
||||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
|
||||||
|
|
||||||
// Loop over matrices.
|
// Loop over matrices.
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
gesdd<T>(
|
svd_work.run(
|
||||||
/* jobz = */ jobz,
|
in_ptr + M * N * i,
|
||||||
// M and N are swapped since lapack expects column-major.
|
s_ptr + K * i,
|
||||||
/* m = */ &N,
|
vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||||
/* n = */ &M,
|
u_ptr ? u_ptr + M * M * i : nullptr);
|
||||||
/* a = */ in_ptr + M * N * i,
|
|
||||||
/* lda = */ &lda,
|
|
||||||
/* s = */ s_ptr + K * i,
|
|
||||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
|
||||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
|
||||||
/* ldu = */ &ldu,
|
|
||||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
|
||||||
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
|
|
||||||
/* ldvt = */ &ldvt,
|
|
||||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
|
||||||
/* lwork = */ &lwork,
|
|
||||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
|
||||||
/* info = */ &info);
|
|
||||||
|
|
||||||
if (info != 0) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
|
||||||
throw std::runtime_error(ss.str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void compute_svd(
|
|
||||||
const array& a,
|
|
||||||
bool compute_uv,
|
|
||||||
std::vector<array>& outputs,
|
|
||||||
Stream stream) {}
|
|
||||||
|
|
||||||
void SVD::eval_cpu(
|
void SVD::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
|
|||||||
case float64:
|
case float64:
|
||||||
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
|
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
|
||||||
break;
|
break;
|
||||||
|
case complex64:
|
||||||
|
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[SVD::eval_cpu] only supports float32 or float64.");
|
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
|
|||||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||||
size_t free, total;
|
size_t free, total;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
memory_limit_ = total * 0.95;
|
memory_limit_ = total * 0.9;
|
||||||
max_pool_size_ = memory_limit_;
|
max_pool_size_ = memory_limit_;
|
||||||
|
|
||||||
int device_count = 0;
|
int device_count = 0;
|
||||||
@@ -176,7 +176,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
|||||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||||
}
|
}
|
||||||
// Copy to managed here if the buffer is not on the right device
|
// Copy to managed here if the buffer is not on the right device
|
||||||
if (buf->device != device) {
|
if (buf->device >= 0 && buf->device != device) {
|
||||||
copy_to_managed(*buf);
|
copy_to_managed(*buf);
|
||||||
}
|
}
|
||||||
return Buffer{buf};
|
return Buffer{buf};
|
||||||
@@ -219,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
|||||||
scalar_pool_.free(buf);
|
scalar_pool_.free(buf);
|
||||||
} else {
|
} else {
|
||||||
if (buf->device >= 0) {
|
if (buf->device >= 0) {
|
||||||
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
|
||||||
} else {
|
} else {
|
||||||
cudaFree(buf->data);
|
CHECK_CUDA_ERROR(cudaFree(buf->data));
|
||||||
}
|
}
|
||||||
delete buf;
|
delete buf;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// keys has shape (N1, ..., NK, 2)
|
// keys has shape (N1, ..., NK, 2)
|
||||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||||
auto& keys = inputs[0];
|
auto& keys = inputs[0];
|
||||||
uint32_t num_keys = keys.size() / 2;
|
size_t num_keys = keys.size() / 2;
|
||||||
|
|
||||||
uint32_t elems_per_key = out.size() / num_keys;
|
size_t elems_per_key = out.size() / num_keys;
|
||||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||||
auto& s = stream();
|
auto& s = stream();
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||||
@@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||||
uint32_t half_size = out_per_key / 2;
|
size_t half_size = out_per_key / 2;
|
||||||
|
|
||||||
bool odd = out_per_key % 2;
|
bool odd = out_per_key % 2;
|
||||||
|
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
|
||||||
|
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
encoder.set_input_array(keys);
|
encoder.set_input_array(keys);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
dim3 grid_dims{num_keys, half_size + odd};
|
int64_t total = num_keys * (half_size + odd);
|
||||||
int64_t total = grid_dims.x * grid_dims.y;
|
uint32_t threads_y = 1;
|
||||||
int32_t threads_y = 1;
|
while ((total / threads_y) >= UINT_MAX) {
|
||||||
while ((total / threads_y) >= (1U << 31)) {
|
|
||||||
threads_y *= 2;
|
threads_y *= 2;
|
||||||
}
|
}
|
||||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
uint32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||||
|
|
||||||
|
dim3 grid_dims{
|
||||||
|
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
|
||||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||||
auto& stream = encoder.stream();
|
auto& stream = encoder.stream();
|
||||||
if (keys.flags().row_contiguous) {
|
if (keys.flags().row_contiguous) {
|
||||||
|
|||||||
@@ -121,14 +121,6 @@ if(NOT MLX_METAL_PATH)
|
|||||||
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
|
|
||||||
26.2))
|
|
||||||
set(MLX_ENABLE_NAX TRUE)
|
|
||||||
target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
|
|
||||||
else()
|
|
||||||
set(MLX_ENABLE_NAX FALSE)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
||||||
|
|
||||||
target_compile_definitions(mlx
|
target_compile_definitions(mlx
|
||||||
|
|||||||
@@ -265,14 +265,19 @@ Device& device(mlx::core::Device);
|
|||||||
|
|
||||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||||
|
|
||||||
#ifdef MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
inline bool is_nax_available() {
|
inline bool is_nax_available() {
|
||||||
static bool is_nax_available_ =
|
auto _check_nax = []() {
|
||||||
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
|
bool can_use_nax = false;
|
||||||
|
if (__builtin_available(
|
||||||
|
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||||
|
can_use_nax = true;
|
||||||
|
}
|
||||||
|
can_use_nax &=
|
||||||
|
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
|
||||||
|
return can_use_nax;
|
||||||
|
};
|
||||||
|
static bool is_nax_available_ = _check_nax();
|
||||||
return is_nax_available_;
|
return is_nax_available_;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
|||||||
@@ -9,13 +9,17 @@ set(BASE_HEADERS
|
|||||||
utils.h)
|
utils.h)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
set(METAL_FLAGS
|
||||||
|
-x
|
||||||
|
metal
|
||||||
|
-Wall
|
||||||
|
-Wextra
|
||||||
|
-fno-fast-math
|
||||||
|
-Wno-c++17-extensions
|
||||||
|
-Wno-c++20-extensions)
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||||
endif()
|
endif()
|
||||||
if(MLX_ENABLE_NAX)
|
|
||||||
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
|
|
||||||
endif()
|
|
||||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||||
set(METAL_FLAGS ${METAL_FLAGS}
|
set(METAL_FLAGS ${METAL_FLAGS}
|
||||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||||
@@ -123,8 +127,8 @@ if(NOT MLX_METAL_JIT)
|
|||||||
build_kernel(gemv_masked steel/utils.h)
|
build_kernel(gemv_masked steel/utils.h)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MLX_ENABLE_NAX)
|
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
|
||||||
|
26.2))
|
||||||
set(STEEL_NAX_HEADERS
|
set(STEEL_NAX_HEADERS
|
||||||
steel/defines.h
|
steel/defines.h
|
||||||
steel/utils.h
|
steel/utils.h
|
||||||
|
|||||||
@@ -172,8 +172,6 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
|||||||
// Regular steel matmul dispatch
|
// Regular steel matmul dispatch
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#ifdef MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
template <bool CHECK_AB>
|
template <bool CHECK_AB>
|
||||||
void steel_matmul_regular_axpby_nax(
|
void steel_matmul_regular_axpby_nax(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@@ -210,11 +208,11 @@ void steel_matmul_regular_axpby_nax(
|
|||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kname << "steel_gemm_fused_nax_"
|
kname << "steel_gemm_fused_nax_"
|
||||||
<< (transpose_a ? 't' : 'n')
|
<< (transpose_a ? 't' : 'n')
|
||||||
<< (transpose_b ? 't' : 'n')
|
<< (transpose_b ? 't' : 'n')
|
||||||
<< "_" << type_to_name(a)
|
<< "_" << type_to_name(a)
|
||||||
<< "_" << type_to_name(out)
|
<< "_" << type_to_name(out)
|
||||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||||
|
|
||||||
@@ -329,8 +327,6 @@ void steel_matmul_regular_axpby_nax(
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
template <bool CHECK_AB>
|
template <bool CHECK_AB>
|
||||||
void steel_matmul_regular_axpby(
|
void steel_matmul_regular_axpby(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
@@ -357,41 +353,35 @@ void steel_matmul_regular_axpby(
|
|||||||
int64_t C_batch_stride /* = 0*/,
|
int64_t C_batch_stride /* = 0*/,
|
||||||
float alpha /* = 1.0f */,
|
float alpha /* = 1.0f */,
|
||||||
float beta /* = 0.0f */) {
|
float beta /* = 0.0f */) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
||||||
|
(env::enable_tf32() || a.dtype() != float32)) {
|
||||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
||||||
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
/* const Stream& s = */ s,
|
||||||
(env::enable_tf32() || a.dtype() != float32)) {
|
/* metal::Device& d = */ d,
|
||||||
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
/* const array& a = */ a,
|
||||||
/* const Stream& s = */ s,
|
/* const array& b = */ b,
|
||||||
/* metal::Device& d = */ d,
|
/* const array& c = */ c,
|
||||||
/* const array& a = */ a,
|
/* array& out = */ out,
|
||||||
/* const array& b = */ b,
|
/* int M = */ M,
|
||||||
/* const array& c = */ c,
|
/* int N = */ N,
|
||||||
/* array& out = */ out,
|
/* int K = */ K,
|
||||||
/* int M = */ M,
|
/* int batch_size_out = */ batch_size_out,
|
||||||
/* int N = */ N,
|
/* int lda = */ lda,
|
||||||
/* int K = */ K,
|
/* int ldb = */ ldb,
|
||||||
/* int batch_size_out = */ batch_size_out,
|
/* int ldd = */ ldd,
|
||||||
/* int lda = */ lda,
|
/* bool transpose_a = */ transpose_a,
|
||||||
/* int ldb = */ ldb,
|
/* bool transpose_b = */ transpose_b,
|
||||||
/* int ldd = */ ldd,
|
/* std::vector<array>& copies = */ copies,
|
||||||
/* bool transpose_a = */ transpose_a,
|
/* Shape batch_shape = */ batch_shape,
|
||||||
/* bool transpose_b = */ transpose_b,
|
/* Strides batch_strides = */ batch_strides,
|
||||||
/* std::vector<array>& copies = */ copies,
|
/* int64_t A_batch_stride = */ A_batch_stride,
|
||||||
/* Shape batch_shape = */ batch_shape,
|
/* int64_t B_batch_stride = */ B_batch_stride,
|
||||||
/* Strides batch_strides = */ batch_strides,
|
/* int64_t matrix_stride_out = */ matrix_stride_out,
|
||||||
/* int64_t A_batch_stride = */ A_batch_stride,
|
/* int64_t C_batch_stride = */ C_batch_stride,
|
||||||
/* int64_t B_batch_stride = */ B_batch_stride,
|
/* float alpha = */ alpha,
|
||||||
/* int64_t matrix_stride_out = */ matrix_stride_out,
|
/* float beta = */ beta);
|
||||||
/* int64_t C_batch_stride = */ C_batch_stride,
|
|
||||||
/* float alpha = */ alpha,
|
|
||||||
/* float beta = */ beta);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
// Determine dispatch kernel
|
// Determine dispatch kernel
|
||||||
@@ -405,11 +395,11 @@ void steel_matmul_regular_axpby(
|
|||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kname << "steel_gemm_fused_"
|
kname << "steel_gemm_fused_"
|
||||||
<< (transpose_a ? 't' : 'n')
|
<< (transpose_a ? 't' : 'n')
|
||||||
<< (transpose_b ? 't' : 'n')
|
<< (transpose_b ? 't' : 'n')
|
||||||
<< "_" << type_to_name(a)
|
<< "_" << type_to_name(a)
|
||||||
<< "_" << type_to_name(out)
|
<< "_" << type_to_name(out)
|
||||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||||
|
|
||||||
@@ -574,14 +564,14 @@ void steel_gemm_splitk_axpby(
|
|||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kname << "steel_gemm_splitk_"
|
kname << "steel_gemm_splitk_"
|
||||||
<< (transpose_a ? 't' : 'n')
|
<< (transpose_a ? 't' : 'n')
|
||||||
<< (transpose_b ? 't' : 'n')
|
<< (transpose_b ? 't' : 'n')
|
||||||
<< "_" << type_to_name(a)
|
<< "_" << type_to_name(a)
|
||||||
<< "_" << type_to_name(C_split)
|
<< "_" << type_to_name(C_split)
|
||||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
<< "_wm" << wm << "_wn" << wn
|
<< "_wm" << wm << "_wn" << wn
|
||||||
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
|
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
|
||||||
<< "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on
|
<< "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on
|
||||||
|
|
||||||
// Encode and dispatch gemm kernel
|
// Encode and dispatch gemm kernel
|
||||||
@@ -915,10 +905,10 @@ void gemv_axbpy(
|
|||||||
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kname << "_bm" << bm << "_bn" << bn
|
kname << "_bm" << bm << "_bn" << bn
|
||||||
<< "_sm" << sm << "_sn" << sn
|
<< "_sm" << sm << "_sn" << sn
|
||||||
<< "_tm" << tm << "_tn" << tn
|
<< "_tm" << tm << "_tn" << tn
|
||||||
<< "_nc" << !contiguous_kernel
|
<< "_nc" << !contiguous_kernel
|
||||||
<< "_axpby" << do_axpby; // clang-format on
|
<< "_axpby" << do_axpby; // clang-format on
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
@@ -1766,8 +1756,6 @@ void gather_mm_rhs(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void gather_mm_rhs_nax(
|
void gather_mm_rhs_nax(
|
||||||
const array& a_,
|
const array& a_,
|
||||||
const array& b_,
|
const array& b_,
|
||||||
@@ -1911,8 +1899,6 @@ void gather_mm_rhs_nax(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void gather_mv(
|
void gather_mv(
|
||||||
const array& mat_,
|
const array& mat_,
|
||||||
const array& vec_,
|
const array& vec_,
|
||||||
@@ -2196,19 +2182,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// We are walking a in order and b is also in order so we can batch up the
|
// We are walking a in order and b is also in order so we can batch up the
|
||||||
// matmuls and reuse reading a and b.
|
// matmuls and reuse reading a and b.
|
||||||
if (M == 1 && right_sorted_ == true) {
|
if (M == 1 && right_sorted_ == true) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
if (metal::is_nax_available() &&
|
||||||
|
(env::enable_tf32() || a.dtype() != float32)) {
|
||||||
if (__builtin_available(
|
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
|
||||||
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
|
||||||
if (metal::is_nax_available() &&
|
|
||||||
!issubdtype(a.dtype(), complexfloating) &&
|
|
||||||
(env::enable_tf32() || a.dtype() != float32)) {
|
|
||||||
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
gather_mm_rhs(a, b, rhs_indices, out, d, s);
|
gather_mm_rhs(a, b, rhs_indices, out, d, s);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -451,8 +451,6 @@ void qvm(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void qmm_nax(
|
void qmm_nax(
|
||||||
const array& x,
|
const array& x,
|
||||||
const array& w,
|
const array& w,
|
||||||
@@ -653,8 +651,6 @@ void gather_qmm_nax(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void qmm(
|
void qmm(
|
||||||
const array& x,
|
const array& x,
|
||||||
const array& w,
|
const array& w,
|
||||||
@@ -670,31 +666,25 @@ void qmm(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::string& mode) {
|
const std::string& mode) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
||||||
|
(env::enable_tf32() || x.dtype() != float32)) {
|
||||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
return qmm_nax(
|
||||||
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
/* const array& x = */ x,
|
||||||
(env::enable_tf32() || x.dtype() != float32)) {
|
/* const array& w = */ w,
|
||||||
return qmm_nax(
|
/* const array& scales = */ scales,
|
||||||
/* const array& x = */ x,
|
/* const std::optional<array>& biases = */ biases,
|
||||||
/* const array& w = */ w,
|
/* array& out = */ out,
|
||||||
/* const array& scales = */ scales,
|
/* bool transpose = */ transpose,
|
||||||
/* const std::optional<array>& biases = */ biases,
|
/* int group_size = */ group_size,
|
||||||
/* array& out = */ out,
|
/* int bits = */ bits,
|
||||||
/* bool transpose = */ transpose,
|
/* int M = */ M,
|
||||||
/* int group_size = */ group_size,
|
/* int N = */ N,
|
||||||
/* int bits = */ bits,
|
/* int K = */ K,
|
||||||
/* int M = */ M,
|
/* metal::Device& d = */ d,
|
||||||
/* int N = */ N,
|
/* const Stream& s = */ s,
|
||||||
/* int K = */ K,
|
/* const std::string& mode = */ mode);
|
||||||
/* metal::Device& d = */ d,
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* const std::string& mode = */ mode);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
int B = out.size() / M / N;
|
int B = out.size() / M / N;
|
||||||
|
|
||||||
int wm = 2;
|
int wm = 2;
|
||||||
@@ -772,33 +762,27 @@ void gather_qmm(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::string& mode) {
|
const std::string& mode) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
||||||
|
(env::enable_tf32() || x.dtype() != float32)) {
|
||||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
return gather_qmm_nax(
|
||||||
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
/* const array& x = */ x,
|
||||||
(env::enable_tf32() || x.dtype() != float32)) {
|
/* const array& w = */ w,
|
||||||
return gather_qmm_nax(
|
/* const array& scales = */ scales,
|
||||||
/* const array& x = */ x,
|
/* const std::optional<array>& biases = */ biases,
|
||||||
/* const array& w = */ w,
|
/* const array& lhs_indices = */ lhs_indices,
|
||||||
/* const array& scales = */ scales,
|
/* const array& rhs_indices = */ rhs_indices,
|
||||||
/* const std::optional<array>& biases = */ biases,
|
/* array& out = */ out,
|
||||||
/* const array& lhs_indices = */ lhs_indices,
|
/* bool transpose = */ transpose,
|
||||||
/* const array& rhs_indices = */ rhs_indices,
|
/* int group_size = */ group_size,
|
||||||
/* array& out = */ out,
|
/* int bits = */ bits,
|
||||||
/* bool transpose = */ transpose,
|
/* int M = */ M,
|
||||||
/* int group_size = */ group_size,
|
/* int N = */ N,
|
||||||
/* int bits = */ bits,
|
/* int K = */ K,
|
||||||
/* int M = */ M,
|
/* metal::Device& d = */ d,
|
||||||
/* int N = */ N,
|
/* const Stream& s = */ s,
|
||||||
/* int K = */ K,
|
/* const std::string& mode = */ mode);
|
||||||
/* metal::Device& d = */ d,
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* const std::string& mode = */ mode);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
int B = out.size() / M / N;
|
int B = out.size() / M / N;
|
||||||
|
|
||||||
int wm = 2;
|
int wm = 2;
|
||||||
@@ -975,8 +959,6 @@ void gather_qvm(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void gather_qmm_rhs_nax(
|
void gather_qmm_rhs_nax(
|
||||||
const array& x_,
|
const array& x_,
|
||||||
const array& w_,
|
const array& w_,
|
||||||
@@ -1108,8 +1090,6 @@ void gather_qmm_rhs_nax(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void gather_qmm_rhs(
|
void gather_qmm_rhs(
|
||||||
const array& x_,
|
const array& x_,
|
||||||
const array& w_,
|
const array& w_,
|
||||||
@@ -1126,32 +1106,26 @@ void gather_qmm_rhs(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::string mode) {
|
const std::string mode) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
if (metal::is_nax_available() && transpose &&
|
||||||
|
(env::enable_tf32() || x_.dtype() != float32)) {
|
||||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
return gather_qmm_rhs_nax(
|
||||||
if (metal::is_nax_available() && transpose &&
|
/* const array& x_ = */ x_,
|
||||||
(env::enable_tf32() || x_.dtype() != float32)) {
|
/* const array& w_ = */ w_,
|
||||||
return gather_qmm_rhs_nax(
|
/* const array& scales_ = */ scales_,
|
||||||
/* const array& x_ = */ x_,
|
/* const std::optional<array>& biases_ = */ biases_,
|
||||||
/* const array& w_ = */ w_,
|
/* const array& indices_ = */ indices_,
|
||||||
/* const array& scales_ = */ scales_,
|
/* array& out = */ out,
|
||||||
/* const std::optional<array>& biases_ = */ biases_,
|
/* bool transpose = */ transpose,
|
||||||
/* const array& indices_ = */ indices_,
|
/* int group_size = */ group_size,
|
||||||
/* array& out = */ out,
|
/* int bits = */ bits,
|
||||||
/* bool transpose = */ transpose,
|
/* int M = */ M,
|
||||||
/* int group_size = */ group_size,
|
/* int N = */ N,
|
||||||
/* int bits = */ bits,
|
/* int K = */ K,
|
||||||
/* int M = */ M,
|
/* metal::Device& d = */ d,
|
||||||
/* int N = */ N,
|
/* const Stream& s = */ s,
|
||||||
/* int K = */ K,
|
/* const std::string mode = */ mode);
|
||||||
/* metal::Device& d = */ d,
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* const std::string mode = */ mode);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
// Start by normalizing the indices
|
// Start by normalizing the indices
|
||||||
array indices = ensure_row_contiguous(indices_, d, s);
|
array indices = ensure_row_contiguous(indices_, d, s);
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ namespace mlx::core::fast {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#ifdef MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void sdpa_full_self_attention_nax(
|
void sdpa_full_self_attention_nax(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@@ -150,8 +148,6 @@ void sdpa_full_self_attention_nax(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
void sdpa_full_self_attention_metal(
|
void sdpa_full_self_attention_metal(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@@ -163,24 +159,20 @@ void sdpa_full_self_attention_metal(
|
|||||||
bool do_causal_,
|
bool do_causal_,
|
||||||
const std::optional<array>& mask,
|
const std::optional<array>& mask,
|
||||||
const std::optional<array>& sinks) {
|
const std::optional<array>& sinks) {
|
||||||
#ifdef MLX_ENABLE_NAX
|
if (metal::is_nax_available() && q.shape(3) != 80 &&
|
||||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
(env::enable_tf32() || q.dtype() != float32)) {
|
||||||
if (metal::is_nax_available() && q.shape(3) != 80 &&
|
return sdpa_full_self_attention_nax(
|
||||||
(env::enable_tf32() || q.dtype() != float32)) {
|
/* const Stream& s = */ s,
|
||||||
return sdpa_full_self_attention_nax(
|
/* metal::Device& d = */ d,
|
||||||
/* const Stream& s = */ s,
|
/* const array& q = */ q,
|
||||||
/* metal::Device& d = */ d,
|
/* const array& k = */ k,
|
||||||
/* const array& q = */ q,
|
/* const array& v = */ v,
|
||||||
/* const array& k = */ k,
|
/* const float scale = */ scale,
|
||||||
/* const array& v = */ v,
|
/* array& o = */ o,
|
||||||
/* const float scale = */ scale,
|
/* bool do_causal_ = */ do_causal_,
|
||||||
/* array& o = */ o,
|
/* const std::optional<array>& mask = */ mask,
|
||||||
/* bool do_causal_ = */ do_causal_,
|
/* const std::optional<array>& sinks = */ sinks);
|
||||||
/* const std::optional<array>& mask = */ mask,
|
|
||||||
/* const std::optional<array>& sinks = */ sinks);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#endif // MLX_ENABLE_NAX
|
|
||||||
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
|||||||
@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
std::vector<array>
|
std::vector<array>
|
||||||
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::svd]");
|
check_cpu_stream(s, "[linalg::svd]");
|
||||||
check_float(a.dtype(), "[linalg::svd]");
|
check_float_or_complex(a.dtype(), "[linalg::svd]");
|
||||||
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
|||||||
s_shape.pop_back();
|
s_shape.pop_back();
|
||||||
s_shape[rank - 2] = std::min(m, n);
|
s_shape[rank - 2] = std::min(m, n);
|
||||||
|
|
||||||
|
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
|
||||||
|
|
||||||
if (!compute_uv) {
|
if (!compute_uv) {
|
||||||
return {array(
|
return {array(
|
||||||
std::move(s_shape),
|
std::move(s_shape),
|
||||||
a.dtype(),
|
s_dtype,
|
||||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||||
{a})};
|
{a})};
|
||||||
}
|
}
|
||||||
@@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
return array::make_arrays(
|
return array::make_arrays(
|
||||||
{u_shape, s_shape, vt_shape},
|
{u_shape, s_shape, vt_shape},
|
||||||
{a.dtype(), a.dtype(), a.dtype()},
|
{a.dtype(), s_dtype, a.dtype()},
|
||||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@@ -703,4 +705,4 @@ array solve_triangular(
|
|||||||
return matmul(a_inv, b, s);
|
return matmul(a_inv, b, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 30
|
#define MLX_VERSION_MINOR 30
|
||||||
#define MLX_VERSION_PATCH 0
|
#define MLX_VERSION_PATCH 1
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
|||||||
@@ -168,6 +168,42 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test float64 - use CPU stream since float64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_f64 = mx.array(
|
||||||
|
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)
|
||||||
|
mx.eval(U_f64, S_f64, Vt_f64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,
|
||||||
|
A_f64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-7,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(S_f64.dtype, mx.float64)
|
||||||
|
|
||||||
|
# Test complex64 - use CPU stream since complex64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_c64 = mx.array(
|
||||||
|
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64
|
||||||
|
)
|
||||||
|
U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)
|
||||||
|
mx.eval(U_c64, S_c64, Vt_c64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,
|
||||||
|
A_c64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-7,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(S_c64.dtype, mx.float32)
|
||||||
|
self.assertEqual(U_c64.dtype, mx.complex64)
|
||||||
|
self.assertEqual(Vt_c64.dtype, mx.complex64)
|
||||||
|
|
||||||
def test_inverse(self):
|
def test_inverse(self):
|
||||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
||||||
A_inv = mx.linalg.inv(A, stream=mx.cpu)
|
A_inv = mx.linalg.inv(A, stream=mx.cpu)
|
||||||
@@ -342,6 +378,43 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
A_np = np.random.randn(3, n, n).astype(np.float32)
|
A_np = np.random.randn(3, n, n).astype(np.float32)
|
||||||
check_eigs_and_vecs(A_np)
|
check_eigs_and_vecs(A_np)
|
||||||
|
|
||||||
|
# Test float64 - use CPU stream since float64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)
|
||||||
|
A_f64 = mx.array(A_np_f64, dtype=mx.float64)
|
||||||
|
eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)
|
||||||
|
mx.eval(eig_vals_f64, eig_vecs_f64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
A_f64 @ eig_vecs_f64,
|
||||||
|
eig_vals_f64[..., None, :] * eig_vecs_f64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Eigenvalues should be complex64 (output dtype)
|
||||||
|
self.assertEqual(eig_vals_f64.dtype, mx.complex64)
|
||||||
|
self.assertEqual(eig_vecs_f64.dtype, mx.complex64)
|
||||||
|
|
||||||
|
# Test complex64 input - use CPU stream since complex64 is not supported on GPU
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
A_np_c64 = np.array(
|
||||||
|
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64
|
||||||
|
)
|
||||||
|
A_c64 = mx.array(A_np_c64, dtype=mx.complex64)
|
||||||
|
eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)
|
||||||
|
mx.eval(eig_vals_c64, eig_vecs_c64)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
A_c64 @ eig_vecs_c64,
|
||||||
|
eig_vals_c64[..., None, :] * eig_vecs_c64,
|
||||||
|
rtol=1e-5,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(eig_vals_c64.dtype, mx.complex64)
|
||||||
|
self.assertEqual(eig_vecs_c64.dtype, mx.complex64)
|
||||||
|
|
||||||
# Test error cases
|
# Test error cases
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array
|
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array
|
||||||
|
|||||||
@@ -1443,23 +1443,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(a.tolist(), expected)
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
|
||||||
def test_unary_ops(self):
|
def test_unary_ops(self):
|
||||||
def test_ops(npop, mlxop, x, y, atol):
|
def test_ops(npop, mlxop, x, y, atol, rtol):
|
||||||
r_np = npop(x)
|
r_np = npop(x)
|
||||||
r_mlx = mlxop(y)
|
r_mlx = mlxop(y)
|
||||||
mx.eval(r_mlx)
|
mx.eval(r_mlx)
|
||||||
|
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
|
||||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
|
||||||
|
|
||||||
x = np.random.rand(18, 28, 38)
|
x = np.random.rand(18, 28, 38)
|
||||||
for op in ["abs", "exp", "log", "square", "sqrt"]:
|
for op in ["abs", "exp", "log", "square", "sqrt"]:
|
||||||
with self.subTest(op=op):
|
with self.subTest(op=op):
|
||||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||||
|
|
||||||
for dtype, atol in float_dtypes:
|
for dtype, atol, rtol in float_dtypes:
|
||||||
with self.subTest(dtype=dtype):
|
with self.subTest(dtype=dtype):
|
||||||
x_ = x.astype(getattr(np, dtype))
|
x_ = x.astype(getattr(np, dtype))
|
||||||
y_ = mx.array(x_)
|
y_ = mx.array(x_)
|
||||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
|
||||||
|
|
||||||
def test_unary_ops_from_non_array(self):
|
def test_unary_ops_from_non_array(self):
|
||||||
unary_ops = [
|
unary_ops = [
|
||||||
@@ -1511,12 +1510,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
|
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
|
||||||
|
|
||||||
def test_trig_ops(self):
|
def test_trig_ops(self):
|
||||||
def test_ops(npop, mlxop, x, y, atol):
|
def test_ops(npop, mlxop, x, y, atol, rtol):
|
||||||
r_np = npop(x)
|
r_np = npop(x)
|
||||||
r_mlx = mlxop(y)
|
r_mlx = mlxop(y)
|
||||||
mx.eval(r_mlx)
|
mx.eval(r_mlx)
|
||||||
|
|
||||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True))
|
self.assertTrue(
|
||||||
|
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
|
||||||
|
)
|
||||||
|
|
||||||
x = np.random.rand(9, 12, 18)
|
x = np.random.rand(9, 12, 18)
|
||||||
xi = np.random.rand(9, 12, 18)
|
xi = np.random.rand(9, 12, 18)
|
||||||
@@ -1526,34 +1527,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
for op in all_fwd_ops:
|
for op in all_fwd_ops:
|
||||||
with self.subTest(op=op):
|
with self.subTest(op=op):
|
||||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||||
|
|
||||||
for dtype, atol in float_dtypes:
|
for dtype, atol, rtol in float_dtypes:
|
||||||
with self.subTest(dtype=dtype):
|
with self.subTest(dtype=dtype):
|
||||||
x_ = x.astype(getattr(np, dtype))
|
x_ = x.astype(getattr(np, dtype))
|
||||||
y_ = mx.array(x_)
|
y_ = mx.array(x_)
|
||||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
|
||||||
|
|
||||||
with self.subTest(op=op):
|
with self.subTest(op=op):
|
||||||
float_dtypes = [("complex64", 1e-5)]
|
dtype = "complex64"
|
||||||
|
with self.subTest(dtype=dtype):
|
||||||
for dtype, atol in float_dtypes:
|
x_ = x + 1.0j * xi
|
||||||
with self.subTest(dtype=dtype):
|
x_ = x_.astype(getattr(np, dtype))
|
||||||
x_ = x + 1.0j * xi
|
y_ = mx.array(x_)
|
||||||
x_ = x_.astype(getattr(np, dtype))
|
test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
|
||||||
y_ = mx.array(x_)
|
|
||||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
|
||||||
|
|
||||||
with self.subTest(op="arc" + op):
|
with self.subTest(op="arc" + op):
|
||||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||||
op_inv = "arc" + op
|
op_inv = "arc" + op
|
||||||
|
|
||||||
for dtype, atol in float_dtypes:
|
for dtype, atol, rtol in float_dtypes:
|
||||||
with self.subTest(dtype=dtype):
|
with self.subTest(dtype=dtype):
|
||||||
np_op_fwd = getattr(np, op)
|
np_op_fwd = getattr(np, op)
|
||||||
x_ = np_op_fwd(x).astype(getattr(np, dtype))
|
x_ = np_op_fwd(x).astype(getattr(np, dtype))
|
||||||
y_ = mx.array(x_)
|
y_ = mx.array(x_)
|
||||||
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol)
|
test_ops(
|
||||||
|
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
|
||||||
|
)
|
||||||
|
|
||||||
# Test grads
|
# Test grads
|
||||||
np_vjp_funcs = {
|
np_vjp_funcs = {
|
||||||
@@ -1579,11 +1580,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
x_ = x.astype(np.float32)
|
x_ = x.astype(np.float32)
|
||||||
y_ = mx.array(x_)
|
y_ = mx.array(x_)
|
||||||
op_ = op
|
op_ = op
|
||||||
atol_ = 1e-5
|
|
||||||
|
|
||||||
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
||||||
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
||||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
|
||||||
|
|
||||||
with self.subTest(op="arc" + op):
|
with self.subTest(op="arc" + op):
|
||||||
np_op_fwd = getattr(np, op)
|
np_op_fwd = getattr(np, op)
|
||||||
@@ -1599,11 +1599,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
x_ = x.astype(np.float32)
|
x_ = x.astype(np.float32)
|
||||||
y_ = mx.array(x_)
|
y_ = mx.array(x_)
|
||||||
op_ = "arc" + op
|
op_ = "arc" + op
|
||||||
atol_ = 1e-5
|
|
||||||
|
|
||||||
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
||||||
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
||||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
|
||||||
|
|
||||||
def test_binary_ops(self):
|
def test_binary_ops(self):
|
||||||
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
|
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -24,8 +24,8 @@ def get_version():
|
|||||||
if "#define MLX_VERSION_PATCH" in l:
|
if "#define MLX_VERSION_PATCH" in l:
|
||||||
patch = l.split()[-1]
|
patch = l.split()[-1]
|
||||||
version = f"{major}.{minor}.{patch}"
|
version = f"{major}.{minor}.{patch}"
|
||||||
pypi_release = os.environ.get("PYPI_RELEASE", False)
|
pypi_release = int(os.environ.get("PYPI_RELEASE", 0))
|
||||||
dev_release = os.environ.get("DEV_RELEASE", False)
|
dev_release = int(os.environ.get("DEV_RELEASE", 0))
|
||||||
if not pypi_release or dev_release:
|
if not pypi_release or dev_release:
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||||
|
|||||||
Reference in New Issue
Block a user