mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
8 Commits
v0.30.0
...
d5f61a93fa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5f61a93fa | ||
|
|
4a09264236 | ||
|
|
0dbc7e5bee | ||
|
|
0d68efd461 | ||
|
|
f9e1a14135 | ||
|
|
d8e9ded928 | ||
|
|
60939d010c | ||
|
|
fdcd2923fd |
@@ -17,6 +17,8 @@ runs:
|
||||
steps:
|
||||
- name: Build Python package
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
pip install build
|
||||
python setup.py clean --all
|
||||
@@ -25,6 +27,8 @@ runs:
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
|
||||
run: |
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
|
||||
2
.github/workflows/pull_request.yml
vendored
2
.github/workflows/pull_request.yml
vendored
@@ -13,7 +13,7 @@ permissions:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
check_lint:
|
||||
|
||||
9
.github/workflows/release.yml
vendored
9
.github/workflows/release.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
build_documentation:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: [self-hosted, macos]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/build-docs
|
||||
@@ -65,14 +65,14 @@ jobs:
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: linux-wheels-${{ matrix.python_version }}
|
||||
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
|
||||
path: wheelhouse/mlx-*.whl
|
||||
- name: Upload CPU artifacts
|
||||
if: matrix.python_version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-cpu
|
||||
name: mlx-cpu-${{ matrix.arch }}
|
||||
path: wheelhouse/mlx_cpu-*.whl
|
||||
|
||||
build_mac_release:
|
||||
@@ -208,7 +208,8 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-cpu
|
||||
pattern: mlx-cpu-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R dist
|
||||
|
||||
@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.95;
|
||||
memory_limit_ = total * 0.9;
|
||||
max_pool_size_ = memory_limit_;
|
||||
|
||||
int device_count = 0;
|
||||
@@ -176,7 +176,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
// Copy to managed here if the buffer is not on the right device
|
||||
if (buf->device != device) {
|
||||
if (buf->device >= 0 && buf->device != device) {
|
||||
copy_to_managed(*buf);
|
||||
}
|
||||
return Buffer{buf};
|
||||
@@ -219,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
if (buf->device >= 0) {
|
||||
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
||||
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
CHECK_CUDA_ERROR(cudaFree(buf->data));
|
||||
}
|
||||
delete buf;
|
||||
}
|
||||
|
||||
@@ -139,10 +139,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
auto& keys = inputs[0];
|
||||
uint32_t num_keys = keys.size() / 2;
|
||||
size_t num_keys = keys.size() / 2;
|
||||
|
||||
uint32_t elems_per_key = out.size() / num_keys;
|
||||
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||
@@ -150,19 +150,25 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
uint32_t half_size = out_per_key / 2;
|
||||
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
size_t half_size = out_per_key / 2;
|
||||
|
||||
bool odd = out_per_key % 2;
|
||||
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
|
||||
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
|
||||
}
|
||||
|
||||
encoder.set_input_array(keys);
|
||||
encoder.set_output_array(out);
|
||||
dim3 grid_dims{num_keys, half_size + odd};
|
||||
int64_t total = grid_dims.x * grid_dims.y;
|
||||
int32_t threads_y = 1;
|
||||
while ((total / threads_y) >= (1U << 31)) {
|
||||
int64_t total = num_keys * (half_size + odd);
|
||||
uint32_t threads_y = 1;
|
||||
while ((total / threads_y) >= UINT_MAX) {
|
||||
threads_y *= 2;
|
||||
}
|
||||
int32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
uint32_t threads_x = cuda::ceil_div(total, threads_y);
|
||||
|
||||
dim3 grid_dims{
|
||||
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
|
||||
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
|
||||
auto& stream = encoder.stream();
|
||||
if (keys.flags().row_contiguous) {
|
||||
|
||||
@@ -121,14 +121,6 @@ if(NOT MLX_METAL_PATH)
|
||||
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
|
||||
endif()
|
||||
|
||||
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
|
||||
26.2))
|
||||
set(MLX_ENABLE_NAX TRUE)
|
||||
target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
|
||||
else()
|
||||
set(MLX_ENABLE_NAX FALSE)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
|
||||
|
||||
target_compile_definitions(mlx
|
||||
|
||||
@@ -265,14 +265,19 @@ Device& device(mlx::core::Device);
|
||||
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
inline bool is_nax_available() {
|
||||
static bool is_nax_available_ =
|
||||
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
|
||||
auto _check_nax = []() {
|
||||
bool can_use_nax = false;
|
||||
if (__builtin_available(
|
||||
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
can_use_nax = true;
|
||||
}
|
||||
can_use_nax &=
|
||||
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
|
||||
return can_use_nax;
|
||||
};
|
||||
static bool is_nax_available_ = _check_nax();
|
||||
return is_nax_available_;
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
||||
@@ -9,13 +9,17 @@ set(BASE_HEADERS
|
||||
utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
set(METAL_FLAGS
|
||||
-x
|
||||
metal
|
||||
-Wall
|
||||
-Wextra
|
||||
-fno-fast-math
|
||||
-Wno-c++17-extensions
|
||||
-Wno-c++20-extensions)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
if(MLX_ENABLE_NAX)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
|
||||
endif()
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
@@ -123,8 +127,8 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
endif()
|
||||
|
||||
if(MLX_ENABLE_NAX)
|
||||
|
||||
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
|
||||
26.2))
|
||||
set(STEEL_NAX_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
|
||||
@@ -172,8 +172,6 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
// Regular steel matmul dispatch
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
template <bool CHECK_AB>
|
||||
void steel_matmul_regular_axpby_nax(
|
||||
const Stream& s,
|
||||
@@ -210,11 +208,11 @@ void steel_matmul_regular_axpby_nax(
|
||||
std::ostringstream kname;
|
||||
|
||||
// clang-format off
|
||||
kname << "steel_gemm_fused_nax_"
|
||||
kname << "steel_gemm_fused_nax_"
|
||||
<< (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(out)
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(out)
|
||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||
|
||||
@@ -329,8 +327,6 @@ void steel_matmul_regular_axpby_nax(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
template <bool CHECK_AB>
|
||||
void steel_matmul_regular_axpby(
|
||||
const Stream& s,
|
||||
@@ -357,41 +353,35 @@ void steel_matmul_regular_axpby(
|
||||
int64_t C_batch_stride /* = 0*/,
|
||||
float alpha /* = 1.0f */,
|
||||
float beta /* = 0.0f */) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
||||
(env::enable_tf32() || a.dtype() != float32)) {
|
||||
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ c,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ ldd,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides batch_strides = */ batch_strides,
|
||||
/* int64_t A_batch_stride = */ A_batch_stride,
|
||||
/* int64_t B_batch_stride = */ B_batch_stride,
|
||||
/* int64_t matrix_stride_out = */ matrix_stride_out,
|
||||
/* int64_t C_batch_stride = */ C_batch_stride,
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta);
|
||||
}
|
||||
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
||||
(env::enable_tf32() || a.dtype() != float32)) {
|
||||
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ c,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ ldd,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides batch_strides = */ batch_strides,
|
||||
/* int64_t A_batch_stride = */ A_batch_stride,
|
||||
/* int64_t B_batch_stride = */ B_batch_stride,
|
||||
/* int64_t matrix_stride_out = */ matrix_stride_out,
|
||||
/* int64_t C_batch_stride = */ C_batch_stride,
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Determine dispatch kernel
|
||||
@@ -405,11 +395,11 @@ void steel_matmul_regular_axpby(
|
||||
std::ostringstream kname;
|
||||
|
||||
// clang-format off
|
||||
kname << "steel_gemm_fused_"
|
||||
kname << "steel_gemm_fused_"
|
||||
<< (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(out)
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(out)
|
||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||
|
||||
@@ -574,14 +564,14 @@ void steel_gemm_splitk_axpby(
|
||||
std::ostringstream kname;
|
||||
|
||||
// clang-format off
|
||||
kname << "steel_gemm_splitk_"
|
||||
kname << "steel_gemm_splitk_"
|
||||
<< (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(C_split)
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(C_split)
|
||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn
|
||||
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
|
||||
<< "_wm" << wm << "_wn" << wn
|
||||
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
|
||||
<< "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
@@ -915,10 +905,10 @@ void gemv_axbpy(
|
||||
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||
|
||||
// clang-format off
|
||||
kname << "_bm" << bm << "_bn" << bn
|
||||
<< "_sm" << sm << "_sn" << sn
|
||||
kname << "_bm" << bm << "_bn" << bn
|
||||
<< "_sm" << sm << "_sn" << sn
|
||||
<< "_tm" << tm << "_tn" << tn
|
||||
<< "_nc" << !contiguous_kernel
|
||||
<< "_nc" << !contiguous_kernel
|
||||
<< "_axpby" << do_axpby; // clang-format on
|
||||
|
||||
// Encode and dispatch kernel
|
||||
@@ -1766,8 +1756,6 @@ void gather_mm_rhs(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
void gather_mm_rhs_nax(
|
||||
const array& a_,
|
||||
const array& b_,
|
||||
@@ -1911,8 +1899,6 @@ void gather_mm_rhs_nax(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
void gather_mv(
|
||||
const array& mat_,
|
||||
const array& vec_,
|
||||
@@ -2196,19 +2182,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// We are walking a in order and b is also in order so we can batch up the
|
||||
// matmuls and reuse reading a and b.
|
||||
if (M == 1 && right_sorted_ == true) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (__builtin_available(
|
||||
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() &&
|
||||
!issubdtype(a.dtype(), complexfloating) &&
|
||||
(env::enable_tf32() || a.dtype() != float32)) {
|
||||
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
|
||||
}
|
||||
if (metal::is_nax_available() &&
|
||||
(env::enable_tf32() || a.dtype() != float32)) {
|
||||
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
gather_mm_rhs(a, b, rhs_indices, out, d, s);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -451,8 +451,6 @@ void qvm(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
void qmm_nax(
|
||||
const array& x,
|
||||
const array& w,
|
||||
@@ -653,8 +651,6 @@ void gather_qmm_nax(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
void qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
@@ -670,31 +666,25 @@ void qmm(
|
||||
metal::Device& d,
|
||||
const Stream& s,
|
||||
const std::string& mode) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
||||
(env::enable_tf32() || x.dtype() != float32)) {
|
||||
return qmm_nax(
|
||||
/* const array& x = */ x,
|
||||
/* const array& w = */ w,
|
||||
/* const array& scales = */ scales,
|
||||
/* const std::optional<array>& biases = */ biases,
|
||||
/* array& out = */ out,
|
||||
/* bool transpose = */ transpose,
|
||||
/* int group_size = */ group_size,
|
||||
/* int bits = */ bits,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::string& mode = */ mode);
|
||||
}
|
||||
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
||||
(env::enable_tf32() || x.dtype() != float32)) {
|
||||
return qmm_nax(
|
||||
/* const array& x = */ x,
|
||||
/* const array& w = */ w,
|
||||
/* const array& scales = */ scales,
|
||||
/* const std::optional<array>& biases = */ biases,
|
||||
/* array& out = */ out,
|
||||
/* bool transpose = */ transpose,
|
||||
/* int group_size = */ group_size,
|
||||
/* int bits = */ bits,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::string& mode = */ mode);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
int B = out.size() / M / N;
|
||||
|
||||
int wm = 2;
|
||||
@@ -772,33 +762,27 @@ void gather_qmm(
|
||||
metal::Device& d,
|
||||
const Stream& s,
|
||||
const std::string& mode) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
||||
(env::enable_tf32() || x.dtype() != float32)) {
|
||||
return gather_qmm_nax(
|
||||
/* const array& x = */ x,
|
||||
/* const array& w = */ w,
|
||||
/* const array& scales = */ scales,
|
||||
/* const std::optional<array>& biases = */ biases,
|
||||
/* const array& lhs_indices = */ lhs_indices,
|
||||
/* const array& rhs_indices = */ rhs_indices,
|
||||
/* array& out = */ out,
|
||||
/* bool transpose = */ transpose,
|
||||
/* int group_size = */ group_size,
|
||||
/* int bits = */ bits,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::string& mode = */ mode);
|
||||
}
|
||||
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
|
||||
(env::enable_tf32() || x.dtype() != float32)) {
|
||||
return gather_qmm_nax(
|
||||
/* const array& x = */ x,
|
||||
/* const array& w = */ w,
|
||||
/* const array& scales = */ scales,
|
||||
/* const std::optional<array>& biases = */ biases,
|
||||
/* const array& lhs_indices = */ lhs_indices,
|
||||
/* const array& rhs_indices = */ rhs_indices,
|
||||
/* array& out = */ out,
|
||||
/* bool transpose = */ transpose,
|
||||
/* int group_size = */ group_size,
|
||||
/* int bits = */ bits,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::string& mode = */ mode);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
int B = out.size() / M / N;
|
||||
|
||||
int wm = 2;
|
||||
@@ -975,8 +959,6 @@ void gather_qvm(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
void gather_qmm_rhs_nax(
|
||||
const array& x_,
|
||||
const array& w_,
|
||||
@@ -1108,8 +1090,6 @@ void gather_qmm_rhs_nax(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
void gather_qmm_rhs(
|
||||
const array& x_,
|
||||
const array& w_,
|
||||
@@ -1126,32 +1106,26 @@ void gather_qmm_rhs(
|
||||
metal::Device& d,
|
||||
const Stream& s,
|
||||
const std::string mode) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() && transpose &&
|
||||
(env::enable_tf32() || x_.dtype() != float32)) {
|
||||
return gather_qmm_rhs_nax(
|
||||
/* const array& x_ = */ x_,
|
||||
/* const array& w_ = */ w_,
|
||||
/* const array& scales_ = */ scales_,
|
||||
/* const std::optional<array>& biases_ = */ biases_,
|
||||
/* const array& indices_ = */ indices_,
|
||||
/* array& out = */ out,
|
||||
/* bool transpose = */ transpose,
|
||||
/* int group_size = */ group_size,
|
||||
/* int bits = */ bits,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::string mode = */ mode);
|
||||
}
|
||||
if (metal::is_nax_available() && transpose &&
|
||||
(env::enable_tf32() || x_.dtype() != float32)) {
|
||||
return gather_qmm_rhs_nax(
|
||||
/* const array& x_ = */ x_,
|
||||
/* const array& w_ = */ w_,
|
||||
/* const array& scales_ = */ scales_,
|
||||
/* const std::optional<array>& biases_ = */ biases_,
|
||||
/* const array& indices_ = */ indices_,
|
||||
/* array& out = */ out,
|
||||
/* bool transpose = */ transpose,
|
||||
/* int group_size = */ group_size,
|
||||
/* int bits = */ bits,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::string mode = */ mode);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
// Start by normalizing the indices
|
||||
array indices = ensure_row_contiguous(indices_, d, s);
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@ namespace mlx::core::fast {
|
||||
|
||||
namespace {
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
void sdpa_full_self_attention_nax(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -150,8 +148,6 @@ void sdpa_full_self_attention_nax(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
void sdpa_full_self_attention_metal(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -163,24 +159,20 @@ void sdpa_full_self_attention_metal(
|
||||
bool do_causal_,
|
||||
const std::optional<array>& mask,
|
||||
const std::optional<array>& sinks) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() && q.shape(3) != 80 &&
|
||||
(env::enable_tf32() || q.dtype() != float32)) {
|
||||
return sdpa_full_self_attention_nax(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& q = */ q,
|
||||
/* const array& k = */ k,
|
||||
/* const array& v = */ v,
|
||||
/* const float scale = */ scale,
|
||||
/* array& o = */ o,
|
||||
/* bool do_causal_ = */ do_causal_,
|
||||
/* const std::optional<array>& mask = */ mask,
|
||||
/* const std::optional<array>& sinks = */ sinks);
|
||||
}
|
||||
if (metal::is_nax_available() && q.shape(3) != 80 &&
|
||||
(env::enable_tf32() || q.dtype() != float32)) {
|
||||
return sdpa_full_self_attention_nax(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& q = */ q,
|
||||
/* const array& k = */ k,
|
||||
/* const array& v = */ v,
|
||||
/* const float scale = */ scale,
|
||||
/* array& o = */ o,
|
||||
/* bool do_causal_ = */ do_causal_,
|
||||
/* const std::optional<array>& mask = */ mask,
|
||||
/* const std::optional<array>& sinks = */ sinks);
|
||||
}
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define MLX_VERSION_MAJOR 0
|
||||
#define MLX_VERSION_MINOR 30
|
||||
#define MLX_VERSION_PATCH 0
|
||||
#define MLX_VERSION_PATCH 1
|
||||
#define MLX_VERSION_NUMERIC \
|
||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||
|
||||
|
||||
@@ -1443,23 +1443,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
def test_unary_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
def test_ops(npop, mlxop, x, y, atol, rtol):
|
||||
r_np = npop(x)
|
||||
r_mlx = mlxop(y)
|
||||
mx.eval(r_mlx)
|
||||
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
|
||||
|
||||
x = np.random.rand(18, 28, 38)
|
||||
for op in ["abs", "exp", "log", "square", "sqrt"]:
|
||||
with self.subTest(op=op):
|
||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
||||
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
for dtype, atol, rtol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
|
||||
|
||||
def test_unary_ops_from_non_array(self):
|
||||
unary_ops = [
|
||||
@@ -1511,12 +1510,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
|
||||
|
||||
def test_trig_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
def test_ops(npop, mlxop, x, y, atol, rtol):
|
||||
r_np = npop(x)
|
||||
r_mlx = mlxop(y)
|
||||
mx.eval(r_mlx)
|
||||
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True))
|
||||
self.assertTrue(
|
||||
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
|
||||
)
|
||||
|
||||
x = np.random.rand(9, 12, 18)
|
||||
xi = np.random.rand(9, 12, 18)
|
||||
@@ -1526,34 +1527,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
for op in all_fwd_ops:
|
||||
with self.subTest(op=op):
|
||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
||||
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
for dtype, atol, rtol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
|
||||
|
||||
with self.subTest(op=op):
|
||||
float_dtypes = [("complex64", 1e-5)]
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x + 1.0j * xi
|
||||
x_ = x_.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
dtype = "complex64"
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x + 1.0j * xi
|
||||
x_ = x_.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
|
||||
|
||||
with self.subTest(op="arc" + op):
|
||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
||||
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||
op_inv = "arc" + op
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
for dtype, atol, rtol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_op_fwd = getattr(np, op)
|
||||
x_ = np_op_fwd(x).astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol)
|
||||
test_ops(
|
||||
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
|
||||
)
|
||||
|
||||
# Test grads
|
||||
np_vjp_funcs = {
|
||||
@@ -1579,11 +1580,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x_ = x.astype(np.float32)
|
||||
y_ = mx.array(x_)
|
||||
op_ = op
|
||||
atol_ = 1e-5
|
||||
|
||||
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
||||
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
|
||||
|
||||
with self.subTest(op="arc" + op):
|
||||
np_op_fwd = getattr(np, op)
|
||||
@@ -1599,11 +1599,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x_ = x.astype(np.float32)
|
||||
y_ = mx.array(x_)
|
||||
op_ = "arc" + op
|
||||
atol_ = 1e-5
|
||||
|
||||
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
||||
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
|
||||
|
||||
def test_binary_ops(self):
|
||||
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
|
||||
|
||||
4
setup.py
4
setup.py
@@ -24,8 +24,8 @@ def get_version():
|
||||
if "#define MLX_VERSION_PATCH" in l:
|
||||
patch = l.split()[-1]
|
||||
version = f"{major}.{minor}.{patch}"
|
||||
pypi_release = os.environ.get("PYPI_RELEASE", False)
|
||||
dev_release = os.environ.get("DEV_RELEASE", False)
|
||||
pypi_release = int(os.environ.get("PYPI_RELEASE", 0))
|
||||
dev_release = int(os.environ.get("DEV_RELEASE", 0))
|
||||
if not pypi_release or dev_release:
|
||||
today = datetime.date.today()
|
||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||
|
||||
Reference in New Issue
Block a user