mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
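
The diff below centralizes the NAX availability check. Previously the build defined an MLX_ENABLE_NAX compile-time flag and every call site wrapped its NAX dispatch in #ifdef MLX_ENABLE_NAX plus an __builtin_available(macOS 26.2, ...) check; now the OS-version test is folded into metal::is_nax_available() itself, so call sites keep only a single runtime condition and the per-site preprocessor guards disappear. A simplified, self-contained sketch of the cached-check pattern the new is_nax_available() uses (illustrative only: gpu_architecture_gen() is a placeholder for the real Metal device query, and this is not the exact MLX source):

// Placeholder for metal::device(mlx::core::Device::gpu).get_architecture_gen() in MLX.
inline int gpu_architecture_gen() {
  return 17; // hypothetical value for illustration
}

inline bool is_nax_available() {
  // Run the OS-version and hardware checks once; the result is cached in a
  // function-local static (initialization is thread-safe since C++11).
  // __builtin_available requires Apple Clang targeting an Apple platform.
  static bool available = [] {
    bool ok = false;
    if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
      ok = true; // OS is new enough to expose the NAX Metal APIs
    }
    return ok && gpu_architecture_gen() >= 17; // and the GPU generation supports NAX
  }();
  return available;
}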
@@ -121,14 +121,6 @@ if(NOT MLX_METAL_PATH)
   set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
 endif()

-if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
-   26.2))
-  set(MLX_ENABLE_NAX TRUE)
-  target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
-else()
-  set(MLX_ENABLE_NAX FALSE)
-endif()
-
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

 target_compile_definitions(mlx
@@ -265,14 +265,19 @@ Device& device(mlx::core::Device);

 std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();

-#ifdef MLX_ENABLE_NAX
-
 inline bool is_nax_available() {
-  static bool is_nax_available_ =
-      metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
+  auto _check_nax = []() {
+    bool can_use_nax = false;
+    if (__builtin_available(
+            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+      can_use_nax = true;
+    }
+    can_use_nax &=
+        metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
+    return can_use_nax;
+  };
+  static bool is_nax_available_ = _check_nax();
   return is_nax_available_;
 }

-#endif // MLX_ENABLE_NAX
-
 } // namespace mlx::core::metal
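
With the version check folded into is_nax_available(), the call sites updated in the hunks below drop their #ifdef MLX_ENABLE_NAX / __builtin_available wrappers and reduce to a single guard of roughly this shape (an illustrative composite of the matmul, quantized, and attention call sites, not an exact excerpt):

if (metal::is_nax_available() &&
    (env::enable_tf32() || a.dtype() != float32)) {
  // dispatch the NAX kernel variant, e.g. steel_matmul_regular_axpby_nax(...)
  return;
}
// otherwise fall through to the regular Metal kernel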
@@ -9,13 +9,17 @@ set(BASE_HEADERS
     utils.h)

 function(build_kernel_base TARGET SRCFILE DEPS)
-  set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
+  set(METAL_FLAGS
+      -x
+      metal
+      -Wall
+      -Wextra
+      -fno-fast-math
+      -Wno-c++17-extensions
+      -Wno-c++20-extensions)
   if(MLX_METAL_DEBUG)
     set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
   endif()
-  if(MLX_ENABLE_NAX)
-    set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
-  endif()
   if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
     set(METAL_FLAGS ${METAL_FLAGS}
         "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -123,8 +127,8 @@ if(NOT MLX_METAL_JIT)
   build_kernel(gemv_masked steel/utils.h)
 endif()

-if(MLX_ENABLE_NAX)
+if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
+   26.2))
   set(STEEL_NAX_HEADERS
       steel/defines.h
       steel/utils.h
@@ -172,8 +172,6 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
 // Regular steel matmul dispatch
 ///////////////////////////////////////////////////////////////////////////////

-#ifdef MLX_ENABLE_NAX
-
 template <bool CHECK_AB>
 void steel_matmul_regular_axpby_nax(
     const Stream& s,
@@ -210,11 +208,11 @@ void steel_matmul_regular_axpby_nax(
   std::ostringstream kname;

   // clang-format off
   kname << "steel_gemm_fused_nax_"
         << (transpose_a ? 't' : 'n')
         << (transpose_b ? 't' : 'n')
         << "_" << type_to_name(a)
         << "_" << type_to_name(out)
         << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn; // clang-format on

@@ -329,8 +327,6 @@ void steel_matmul_regular_axpby_nax(
   d.add_temporaries(std::move(copies), s.index);
 }

-#endif // MLX_ENABLE_NAX
-
 template <bool CHECK_AB>
 void steel_matmul_regular_axpby(
     const Stream& s,
@@ -357,41 +353,35 @@ void steel_matmul_regular_axpby(
     int64_t C_batch_stride /* = 0*/,
     float alpha /* = 1.0f */,
     float beta /* = 0.0f */) {
-#ifdef MLX_ENABLE_NAX
-  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
   if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
       (env::enable_tf32() || a.dtype() != float32)) {
     return steel_matmul_regular_axpby_nax<CHECK_AB>(
         /* const Stream& s = */ s,
         /* metal::Device& d = */ d,
         /* const array& a = */ a,
         /* const array& b = */ b,
         /* const array& c = */ c,
         /* array& out = */ out,
         /* int M = */ M,
         /* int N = */ N,
         /* int K = */ K,
         /* int batch_size_out = */ batch_size_out,
         /* int lda = */ lda,
         /* int ldb = */ ldb,
         /* int ldd = */ ldd,
         /* bool transpose_a = */ transpose_a,
         /* bool transpose_b = */ transpose_b,
         /* std::vector<array>& copies = */ copies,
         /* Shape batch_shape = */ batch_shape,
         /* Strides batch_strides = */ batch_strides,
         /* int64_t A_batch_stride = */ A_batch_stride,
         /* int64_t B_batch_stride = */ B_batch_stride,
         /* int64_t matrix_stride_out = */ matrix_stride_out,
         /* int64_t C_batch_stride = */ C_batch_stride,
         /* float alpha = */ alpha,
         /* float beta = */ beta);
-    }
   }
-
-#endif // MLX_ENABLE_NAX

   using namespace mlx::steel;

   // Determine dispatch kernel
@@ -405,11 +395,11 @@ void steel_matmul_regular_axpby(
   std::ostringstream kname;

   // clang-format off
   kname << "steel_gemm_fused_"
         << (transpose_a ? 't' : 'n')
         << (transpose_b ? 't' : 'n')
         << "_" << type_to_name(a)
         << "_" << type_to_name(out)
         << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn; // clang-format on

@@ -574,14 +564,14 @@ void steel_gemm_splitk_axpby(
   std::ostringstream kname;

   // clang-format off
   kname << "steel_gemm_splitk_"
         << (transpose_a ? 't' : 'n')
         << (transpose_b ? 't' : 'n')
         << "_" << type_to_name(a)
         << "_" << type_to_name(C_split)
         << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn
         << "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
         << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on

   // Encode and dispatch gemm kernel
@@ -915,10 +905,10 @@ void gemv_axbpy(
   const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);

   // clang-format off
   kname << "_bm" << bm << "_bn" << bn
         << "_sm" << sm << "_sn" << sn
         << "_tm" << tm << "_tn" << tn
         << "_nc" << !contiguous_kernel
         << "_axpby" << do_axpby; // clang-format on

   // Encode and dispatch kernel
@@ -1766,8 +1756,6 @@ void gather_mm_rhs(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#ifdef MLX_ENABLE_NAX
-
 void gather_mm_rhs_nax(
     const array& a_,
     const array& b_,
@@ -1911,8 +1899,6 @@ void gather_mm_rhs_nax(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#endif // MLX_ENABLE_NAX
-
 void gather_mv(
     const array& mat_,
     const array& vec_,
@@ -2196,19 +2182,10 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We are walking a in order and b is also in order so we can batch up the
   // matmuls and reuse reading a and b.
   if (M == 1 && right_sorted_ == true) {
-#ifdef MLX_ENABLE_NAX
-    if (__builtin_available(
-            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
-      if (metal::is_nax_available() &&
-          !issubdtype(a.dtype(), complexfloating) &&
-          (env::enable_tf32() || a.dtype() != float32)) {
-        return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
-      }
+    if (metal::is_nax_available() &&
+        (env::enable_tf32() || a.dtype() != float32)) {
+      return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
     }
-
-#endif // MLX_ENABLE_NAX
-
     gather_mm_rhs(a, b, rhs_indices, out, d, s);
     return;
   }
@@ -451,8 +451,6 @@ void qvm(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#ifdef MLX_ENABLE_NAX
-
 void qmm_nax(
     const array& x,
     const array& w,
@@ -653,8 +651,6 @@ void gather_qmm_nax(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#endif // MLX_ENABLE_NAX
-
 void qmm(
     const array& x,
     const array& w,
@@ -670,31 +666,25 @@ void qmm(
     metal::Device& d,
     const Stream& s,
     const std::string& mode) {
-#ifdef MLX_ENABLE_NAX
-  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
   if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
       (env::enable_tf32() || x.dtype() != float32)) {
     return qmm_nax(
         /* const array& x = */ x,
         /* const array& w = */ w,
         /* const array& scales = */ scales,
         /* const std::optional<array>& biases = */ biases,
         /* array& out = */ out,
         /* bool transpose = */ transpose,
         /* int group_size = */ group_size,
         /* int bits = */ bits,
         /* int M = */ M,
         /* int N = */ N,
         /* int K = */ K,
         /* metal::Device& d = */ d,
         /* const Stream& s = */ s,
         /* const std::string& mode = */ mode);
-    }
   }
-
-#endif // MLX_ENABLE_NAX

   int B = out.size() / M / N;

   int wm = 2;
@@ -772,33 +762,27 @@ void gather_qmm(
     metal::Device& d,
     const Stream& s,
     const std::string& mode) {
-#ifdef MLX_ENABLE_NAX
-  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
   if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
       (env::enable_tf32() || x.dtype() != float32)) {
     return gather_qmm_nax(
         /* const array& x = */ x,
         /* const array& w = */ w,
         /* const array& scales = */ scales,
         /* const std::optional<array>& biases = */ biases,
         /* const array& lhs_indices = */ lhs_indices,
         /* const array& rhs_indices = */ rhs_indices,
         /* array& out = */ out,
         /* bool transpose = */ transpose,
         /* int group_size = */ group_size,
         /* int bits = */ bits,
         /* int M = */ M,
         /* int N = */ N,
         /* int K = */ K,
         /* metal::Device& d = */ d,
         /* const Stream& s = */ s,
         /* const std::string& mode = */ mode);
-    }
   }
-
-#endif // MLX_ENABLE_NAX

   int B = out.size() / M / N;

   int wm = 2;
@@ -975,8 +959,6 @@ void gather_qvm(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#ifdef MLX_ENABLE_NAX
-
 void gather_qmm_rhs_nax(
     const array& x_,
     const array& w_,
@@ -1108,8 +1090,6 @@ void gather_qmm_rhs_nax(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#endif // MLX_ENABLE_NAX
-
 void gather_qmm_rhs(
     const array& x_,
     const array& w_,
@@ -1126,32 +1106,26 @@ void gather_qmm_rhs(
     metal::Device& d,
     const Stream& s,
     const std::string mode) {
-#ifdef MLX_ENABLE_NAX
-  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
   if (metal::is_nax_available() && transpose &&
       (env::enable_tf32() || x_.dtype() != float32)) {
     return gather_qmm_rhs_nax(
         /* const array& x_ = */ x_,
         /* const array& w_ = */ w_,
         /* const array& scales_ = */ scales_,
         /* const std::optional<array>& biases_ = */ biases_,
         /* const array& indices_ = */ indices_,
         /* array& out = */ out,
         /* bool transpose = */ transpose,
         /* int group_size = */ group_size,
         /* int bits = */ bits,
         /* int M = */ M,
         /* int N = */ N,
         /* int K = */ K,
         /* metal::Device& d = */ d,
         /* const Stream& s = */ s,
         /* const std::string mode = */ mode);
-    }
   }
-
-#endif // MLX_ENABLE_NAX

   // Start by normalizing the indices
   array indices = ensure_row_contiguous(indices_, d, s);

@@ -13,8 +13,6 @@ namespace mlx::core::fast {

 namespace {

-#ifdef MLX_ENABLE_NAX
-
 void sdpa_full_self_attention_nax(
     const Stream& s,
     metal::Device& d,
@@ -150,8 +148,6 @@ void sdpa_full_self_attention_nax(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

-#endif // MLX_ENABLE_NAX
-
 void sdpa_full_self_attention_metal(
     const Stream& s,
     metal::Device& d,
@@ -163,24 +159,20 @@ void sdpa_full_self_attention_metal(
     bool do_causal_,
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
-#ifdef MLX_ENABLE_NAX
-  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
   if (metal::is_nax_available() && q.shape(3) != 80 &&
       (env::enable_tf32() || q.dtype() != float32)) {
     return sdpa_full_self_attention_nax(
         /* const Stream& s = */ s,
         /* metal::Device& d = */ d,
         /* const array& q = */ q,
         /* const array& k = */ k,
         /* const array& v = */ v,
         /* const float scale = */ scale,
         /* array& o = */ o,
         /* bool do_causal_ = */ do_causal_,
         /* const std::optional<array>& mask = */ mask,
         /* const std::optional<array>& sinks = */ sinks);
-    }
   }
-#endif // MLX_ENABLE_NAX

   using namespace mlx::steel;
