Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Add Neural Accelerator Support (#2772)
@@ -121,6 +121,14 @@ if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()

if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(MLX_ENABLE_NAX TRUE)
target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
else()
set(MLX_ENABLE_NAX FALSE)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)

target_compile_definitions(mlx

@@ -265,4 +265,14 @@ Device& device(mlx::core::Device);

std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();

#ifdef MLX_ENABLE_NAX

inline bool is_nax_available() {
static bool is_nax_available_ =
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return is_nax_available_;
}

#endif // MLX_ENABLE_NAX

} // namespace mlx::core::metal

@@ -9,10 +9,13 @@ set(BASE_HEADERS
utils.h)

function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_ENABLE_NAX)
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
endif()
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(METAL_FLAGS ${METAL_FLAGS}
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -120,6 +123,30 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()

if(MLX_ENABLE_NAX)

set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/transforms.h
steel/gemm/nax.h
steel/gemm/gemm_nax.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)

build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})

build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})

set(STEEL_NAX_ATTN_HEADERS
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
steel/utils/integral_constant.h)

build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS})
endif()

add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
mlx/backend/metal/kernels/fp_quantized_nax.h (new file, 1066 lines; diff suppressed because it is too large)

mlx/backend/metal/kernels/fp_quantized_nax.metal (new file, 74 lines)
@@ -0,0 +1,74 @@
// Copyright © 2025 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/fp_quantized_nax.h"


#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
batched)

#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
fp_ ## name, \
type, \
32, \
4, \
aligned)

#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
aligned, \
batched)

#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
4, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)


#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0)


#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false)

#define instantiate_quantized_types(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)

instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

mlx/backend/metal/kernels/quantized_nax.h (new file, 1705 lines; diff suppressed because it is too large)

mlx/backend/metal/kernels/quantized_nax.metal (new file, 106 lines)
@@ -0,0 +1,106 @@
// Copyright © 2023-2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/quantized_nax.h"

#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits, bm, bk, bn, wm, wn)

#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched, bm, bk, bn, wm, wn)

#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned, bm, bk, bn, wm, wn)

#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched, bm, bk, bn, wm, wn)

#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)

#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)

#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)


#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)

#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)

#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)

#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)

#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(5) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)

instantiate_quantized_all() // clang-format on
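For clarity, one concrete expansion of the macros above (illustrative only, not part of the commit; it assumes float16_t with group size 64 and 4 bits, and the instantiate_kernel helper defined elsewhere in the kernel headers):

instantiate_quantized_aligned_batched(affine_qmm_t_nax, float16_t, 64, 4, 64, 64, 64, 2, 2, true, 1)
// expands to
instantiate_kernel(
    "affine_qmm_t_nax_float16_t_gs_64_b_4_bm64_bn64_bk64_wm2_wn2_alN_true_batch_1",
    affine_qmm_t_nax, float16_t, 64, 4, true, 1, 64, 64, 64, 2, 2)

The first argument is the externally visible pipeline name, which the host-side dispatch code in the matmul changes further down rebuilds string-by-string when it looks the kernel up.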
@@ -0,0 +1,476 @@
// Copyright © 2024-25 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];

constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];

template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}

METAL_FUNC T apply(T x) const {
return scale * x;
}
};

struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};

struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};

struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};

struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};

struct ExpSubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp2(x - y);
}
};

struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x / y;
}
};

// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename MaskType = float,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on

// Pacifying compiler
(void)lid;
(void)simd_lane_id;

// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};

Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence

ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head

V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head

O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence

if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch
tidl.y * mask_params->M_strides[1]; // Head
}

const metal::uniform<float> scale2 =
make_uniform(params->scale) * make_uniform(1.44269504089f);

// Prepare MMA tiles
constexpr short UQ = 16;
constexpr short UD = 32;

constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");

// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * UQ);
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / UD;

static_assert(TQ == 1, "Check TQ");

using OSubTile = NAXSubTile<AccumType, UQ, UD>;
NAXTile<AccumType, TQ, TD, OSubTile> Otile;

Otile.clear();

// Prepare mma tile offsets
const short2 simd_coord = OSubTile::NAXFrag_t::get_coord();
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = UQ * TQ * simd_group_id;

Q += (tm + sm) * int(params->Q_strides[2]) + sn;
K += sm * int(params->K_strides[2]) + sn;
V += sm * int(params->V_strides[2]) + sn;

// Init row reduction variables
constexpr short kRowsPT = decltype(Otile)::kRowsPerThread;

metal::vec<AccumType, kRowsPT> max_score;
metal::vec<AccumType, kRowsPT> sum_score{0};

// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::finite_min;
}

if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}

int kb_lim = params->NK;

if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
}

const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
// const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
const bool is_last_q = is_last_bq;

const short lim_rows_q = params->qL_rem - (tm + sm);
const short lim_rows_k = params->kL_rem - sm;

// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
const int is_last_k = (kb == (params->NK_aligned));

// Do S = Q @ K.T
constexpr short UDs = 16;
constexpr short UKs = 32;

constexpr short TDs = BD / UDs;
constexpr short TKs = BK / UKs;

using SSubTile = NAXSubTile<AccumType, UQ, UKs>;
using QSubTile = NAXSubTile<T, UQ, UDs>;
using KSubTile = NAXSubTile<T, UKs, UDs>;

NAXTile<AccumType, TQ, TKs, SSubTile> Stile;

Stile.clear();

STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TKs; ik++) {
STEEL_PRAGMA_UNROLL
for (short id = 0; id < TDs; id++) {
NAXTile<T, 1, 1, QSubTile> Qtile;
NAXTile<T, 1, 1, KSubTile> Ktile;

const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs;
const int K_load_off =
ik * UKs * int(params->K_strides[2]) + id * UDs;

if (!align_Q && is_last_q) {
// Qtile.load_rows(
// Q + Q_load_off,
// int(params->Q_strides[2]),
// lim_rows_q - iq * UQ);
Qtile.load_safe(
Q + Q_load_off,
int(params->Q_strides[2]),
short2(BD, lim_rows_q - iq * UQ));
} else {
Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
}

if (!align_K && is_last_k) {
// Ktile.load_rows(
// K + K_load_off,
// int(params->K_strides[2]),
// lim_rows_k - ik * UKs);
Ktile.load_safe(
K + K_load_off,
int(params->K_strides[2]),
short2(BD, lim_rows_k - ik * UKs));
} else {
Ktile.load(K + K_load_off, int(params->K_strides[2]));
}

subtile_matmad_nax(
Stile.subtile_at(iq, ik),
Qtile.subtile_at(0, 0),
metal::false_type{},
Ktile.subtile_at(0, 0),
metal::true_type{});
}
}
}

// Scale S
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
Stile.elems()[ii] *= float(scale2);
}

// Scale and Retile S
constexpr short UK = 16;
constexpr short TK = BK / UK;
using PSubTile = NAXSubTile<AccumType, UQ, UK>;

NAXTile<AccumType, TQ, TK, PSubTile> Ptile;

STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
Ptile.elems()[ii] = Stile.elems()[ii];
}

// Mask out length sequence
if (!align_K && is_last_k) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;

STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
const short col_pos = sn + ik * UK;

thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
const auto loc = ii * PSubTile::kFragThrCols + jj;
fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc];
}
}
}
}
}

// Mask out if causal
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;

const int base_row = tid.x * BQ + params->qL_off + tm;
const int base_col = kb * BK;

STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
const short row_pos = base_row + iq * UQ;
const short col_pos = base_col + ik * UK;

thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm;
const auto c = col_pos + jj + sn;
const auto loc = ii * PSubTile::kFragThrCols + jj;
fg[loc] = (r < c) ? neg_inf : fg[loc];
}
}
}
}
}

// Other masking as needed
if (has_mask) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;

const int base_row = tid.x * BQ + tm;
const int base_col = kb * BK;

constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;
using MSubTile = NAXSubTile<melem_t, UQ, UK>;

STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
const short row_pos = base_row + iq * UQ + sm;
const short col_pos = base_col + ik * UK + sn;

MSubTile mfrag;
mfrag.load_safe(
mask,
int(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,
row_pos,
col_pos);

thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) {
if constexpr (is_bool) {
fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf;
} else {
fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]);
}
}
}
}
}

// Do softmax

// Temp variables
metal::vec<AccumType, kRowsPT> new_max;
metal::vec<AccumType, kRowsPT> factor;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}

// Row max
Ptile.template row_reduce<MaxOp>(new_max);

// exp(Si - rowmax(Si))
Ptile.template row_bin_op<ExpSubOp>(new_max);

// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp2(max_score[i] - new_max[i]);
max_score[i] = new_max[i];
}

// Row Sum
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i];
}

Ptile.template row_reduce<SumOp>(sum_score);

// Update O
Otile.template row_bin_op<MulOp>(factor);

simdgroup_barrier(mem_flags::mem_none);

// Do O = P @ V
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short id = 0; id < TD; id++) {
if constexpr (BD == 128) {
if (id == 2) {
threadgroup_barrier(mem_flags::mem_none);
}
}

STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
using VSubTile = NAXSubTile<T, UK, UD>;
NAXTile<T, 1, 1, VSubTile> Vtile;

const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD;

if (!align_K && is_last_k) {
// Vtile.load_rows(
// V + V_load_off,
// int(params->V_strides[2]),
// lim_rows_k - ik * UK);
Vtile.load_safe(
V + V_load_off,
int(params->V_strides[2]),
short2(BD, lim_rows_k - ik * UK));
} else {
Vtile.load(V + V_load_off, int(params->V_strides[2]));
}

subtile_matmad_nax(
Otile.subtile_at(iq, id),
Ptile.subtile_at(iq, ik),
metal::bool_constant<false>{},
Vtile.subtile_at(0, 0),
metal::bool_constant<false>{});
}
}
}

// Prepare for next iteration
K += BK * int(params->K_strides[2]);
V += BK * int(params->V_strides[2]);
}

// Normalize output

threadgroup_barrier(mem_flags::mem_none);

metal::vec<AccumType, kRowsPT> rcp;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
rcp[i] = (1.f / sum_score[i]);
}

Otile.template row_bin_op<MulOp>(rcp);

// Store results
O += (tm + sm) * int(params->O_strides[2]) + sn;

if (!align_Q && is_last_q) {
if (lim_rows_q <= 0)
return;

// Otile.store_rows(O, params->O_strides[2], lim_rows_q);
Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q));
} else {
Otile.store(O, int(params->O_strides[2]));
}
}
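For reference, the main loop above implements the standard streaming (online) softmax, carried out in base-2 arithmetic: the attention scale is pre-multiplied by log2(e) (the 1.44269504089 factor) and exp2 replaces exp throughout. Per output row, with running maximum m and running sum l, each KV block performs

\[ m' = \max\bigl(m,\ \max_j S_j\bigr), \qquad P_j = 2^{\,S_j - m'}, \qquad f = 2^{\,m - m'}, \]
\[ \ell' = f\,\ell + \sum_j P_j, \qquad O' = f\,O + P V, \]

and after the last block the kernel divides O by l. When a sink is present it simply seeds m and l before the loop.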
@@ -0,0 +1,33 @@
// Copyright © 2024-25 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/nax.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h"

#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
instantiate_kernel( \
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
"_wm" #wm "_wn" #wn "_mask" #mname, \
attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float)

#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype)

#define instantiate_attn_mask_helper(iname, itype) \
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
instantiate_attn_shapes_helper(iname, itype, bool_, bool)

instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat);

instantiate_attn_mask_helper(float32, float);
// clang-format on

mlx/backend/metal/kernels/steel/attn/nax.h (new file, 1076 lines; diff suppressed because it is too large)

@@ -1,4 +1,7 @@
// Copyright © 2024 Apple Inc.

#pragma once

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")

mlx/backend/metal/kernels/steel/gemm/gemm_nax.h (new file, 154 lines)
@@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"

using namespace metal;

namespace mlx::steel {

template <
typename T,
short SM,
short SN,
short SK,
short BK,
bool transpose_a,
bool transpose_b,
bool kAlignedM,
bool kAlignedN,
bool kAlignedK,
short UM,
short UN,
short UK,
typename AccumType = float>
auto gemm_loop(
const device T* A,
const device T* B,
const constant GEMMParams* params [[buffer(4)]],
const short sgp_sm,
const short sgp_sn) {
constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;

constexpr int RA = transpose_a ? TK : TM;
constexpr int CA = transpose_a ? TM : TK;

constexpr int RB = transpose_b ? TN : TK;
constexpr int CB = transpose_b ? TK : TN;

using DSubTile = NAXSubTile<AccumType, UM, UN>;
using ASubTile =
NAXSubTile<T, (transpose_a ? UK : UM), (transpose_a ? UM : UK)>;
using BSubTile =
NAXSubTile<T, (transpose_b ? UN : UK), (transpose_b ? UK : UN)>;

NAXTile<AccumType, TM, TN, DSubTile> Dtile;
Dtile.clear();

int gemm_k_iterations_ = params->gemm_k_iterations_aligned;

STEEL_PRAGMA_NO_UNROLL
for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
threadgroup_barrier(mem_flags::mem_none);

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, RA, CA, ASubTile> Atile;
NAXTile<T, RB, CB, BSubTile> Btile;
const int k = kk1;

volatile int compiler_barrier;

const int A_offset = transpose_a ? k * params->lda : k;
const int B_offset = transpose_b ? k : k * params->ldb;

if constexpr (kAlignedM) {
Atile.load(A + A_offset, params->lda);
} else {
const short rmax = transpose_a ? SK : sgp_sm;
const short cmax = transpose_a ? sgp_sm : SK;
Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
}

if constexpr (kAlignedN) {
Btile.load(B + B_offset, params->ldb);
} else {
const short rmax = transpose_b ? sgp_sn : SK;
const short cmax = transpose_b ? SK : sgp_sn;
Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
}

tile_matmad_nax(
Dtile,
Atile,
metal::bool_constant<transpose_a>{},
Btile,
metal::bool_constant<transpose_b>{});

(void)compiler_barrier;
}

A += transpose_a ? (BK * params->lda) : BK;
B += transpose_b ? BK : (BK * params->ldb);
}

if constexpr (!kAlignedK) {
simdgroup_barrier(mem_flags::mem_none);

const short rem_bk = params->K - gemm_k_iterations_ * BK;

STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
NAXTile<T, 1, 1, ASubTile> Atile;
NAXTile<T, 1, 1, BSubTile> Btile;

STEEL_PRAGMA_UNROLL
for (int mm = 0; mm < TM; mm++) {
STEEL_PRAGMA_UNROLL
for (int nn = 0; nn < TN; nn++) {
STEEL_PRAGMA_UNROLL
for (int kk = 0; kk < TK; kk++) {
const int m = mm * UM;
const int n = nn * UN;
const int k = kk1 + kk * UK;
const short psk = max(0, rem_bk - k);

const int A_offset =
transpose_a ? (m + k * params->lda) : (m * params->lda + k);
const int B_offset =
transpose_b ? (k + n * params->ldb) : (k * params->ldb + n);

{
const short psm = kAlignedM ? SM : max(0, sgp_sm - m);
const short rmax = transpose_a ? psk : psm;
const short cmax = transpose_a ? psm : psk;
Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
}

{
const short psn = kAlignedN ? SN : max(0, sgp_sn - n);
const short rmax = transpose_b ? psn : psk;
const short cmax = transpose_b ? psk : psn;
Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
}

subtile_matmad_nax(
Dtile.subtile_at(mm, nn),
Atile.subtile_at(0, 0),
metal::bool_constant<transpose_a>{},
Btile.subtile_at(0, 0),
metal::bool_constant<transpose_b>{});
}
}
}
}
}

return Dtile;
}

} // namespace mlx::steel

@@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.

using namespace mlx::steel;

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

// clang-format off
template <
bool kAlignedM,
bool kAlignedN,
typename NAXTile_t,
typename T>
void gemm_epilogue(
thread NAXTile_t& Dtile,
const device T* C,
const constant GEMMParams* params,
const constant GEMMAddMMParams* addmm_params,
const short sgp_sm,
const short sgp_sn) { // clang-format on

(void)params;

constexpr short UM = NAXTile_t::kSubTileRows;
constexpr short UN = NAXTile_t::kSubTileCols;
using CSubTile = NAXSubTile<T, UM, UN>;

using V = typename NAXTile_t::elem_type;

constexpr short TM = NAXTile_t::kTileRows;
constexpr short TN = NAXTile_t::kTileCols;
constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile;

STEEL_PRAGMA_UNROLL
for (short mm = 0; mm < TM; mm++) {
STEEL_PRAGMA_UNROLL
for (short nn = 0; nn < TN; nn++) {
const short m = mm * UM;
const short n = nn * UN;

CSubTile CTile;

if constexpr (kAlignedM && kAlignedN) {
CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n);
} else {
CTile.load_safe(
C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
}

auto delems = Dtile.subtile_at(mm, nn).elems();
auto celems = CTile.elems();

STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemsPerSubTile; i++) {
if (do_axpby) {
delems[i] = addmm_params->alpha * delems[i] +
addmm_params->beta * static_cast<V>(celems[i]);
} else {
delems[i] += static_cast<V>(celems[i]);
}
}
}
}
}

// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;

// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}

// Adjust for batch
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

A += batch_offsets.x;
B += batch_offsets.y;

if (use_out_source) {
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;

if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}

D += params->batch_stride_d * tid.z;

// Prepare threadgroup memory
threadgroup_barrier(mem_flags::mem_none);

// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);

A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;

if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}

constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;

constexpr short TM = SM / UM;
constexpr short TN = SN / UN;

const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);

const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

A += transpose_a ? tm : (tm * params->lda);
B += transpose_b ? (tn * params->ldb) : tn;
D += tm * params->ldd + tn;

if (use_out_source) {
C += tm * addmm_params->ldc + tn * addmm_params->fdc;
}

using DSubTile = NAXSubTile<AccumType, UM, UN>;
NAXTile<AccumType, TM, TN, DSubTile> Dtile;

dispatch_bool(align_K, [&](auto kAlignedK) {
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
Dtile = gemm_loop<
T,
SM,
SN,
SK,
BK,
transpose_a,
transpose_b,
kAlignedM.value,
kAlignedN.value,
kAlignedK.value,
UM,
UN,
UK,
AccumType>(A, B, params, sgp_sm, sgp_sn);
if (use_out_source) {
gemm_epilogue<kAlignedM.value, kAlignedN.value>(
Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
}
if constexpr (kAlignedM && kAlignedN) {
Dtile.store(D, int(params->ldd));
} else {
Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
}
});
});
});
}

@@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h"

// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gemm_fused_nax_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on

@@ -0,0 +1,132 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
gather_mm_rhs_nax(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;
constexpr short TM = SM / UM;
constexpr short TN = SN / UN;

if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}

// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);

A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
rhs_indices += c_row;

const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);

const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

A += transpose_a ? tm : (tm * params->lda);
B += transpose_b ? (tn * params->ldb) : tn;
C += tm * params->ldd + tn;
rhs_indices += tm;

// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[0];
short offset_next = 0;
int n = 0;
while (n < sgp_sm) {
n++;
offset = offset_next;
index = index_next;
offset_next = sgp_sm;
for (; n < sgp_sm; n++) {
if (rhs_indices[n] != index) {
offset_next = n;
index_next = rhs_indices[n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);

using DSubTile = NAXSubTile<AccumType, UM, UN>;
NAXTile<AccumType, TM, TN, DSubTile> Ctile;

dispatch_bool(align_K, [&](auto kAlignedK) {
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
auto do_gemm = gemm_loop<
T,
SM,
SN,
SK,
BK,
transpose_a,
transpose_b,
kAlignedM.value,
kAlignedN.value,
kAlignedK.value,
UM,
UN,
UK,
AccumType>;
Ctile = do_gemm(
A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);

if constexpr (kAlignedN.value) {
if (offset_next - offset == SM) {
Ctile.store(C, int(params->ldd));
} else {
Ctile.store_slice(
C,
int(params->ldd),
short2(0, offset),
short2(SN, offset_next));
}
} else {
Ctile.store_slice(
C,
int(params->ldd),
short2(0, offset),
short2(sgp_sn, offset_next));
}
});
});
});
}
}

@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"

// clang-format off
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs_nax, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)

#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4)
// clang-format on

instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);

mlx/backend/metal/kernels/steel/gemm/nax.h (new file, 1084 lines; diff suppressed because it is too large)

@@ -74,6 +74,44 @@ integral_const_binop(>=, operator>=);
integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(true_type, T) {
return true_type{};
}
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(T, true_type) {
return true_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(false_type, T) {
return false_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(T, false_type) {
return false_type{};
}

// Dispatch utilities
template <typename F>
void dispatch_bool(bool v, F f) {
if (v) {
f(true_type{});
} else {
f(false_type{});
}
}

template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
if constexpr (start < stop) {
constexpr auto idx = Int<start>{};
f(idx);
const_for_loop<start + step, stop, step, F>(f);
}
}

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
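A minimal usage sketch for dispatch_bool, mirroring the GEMM kernels above (the branch bodies here are placeholders): it lifts a runtime flag into a compile-time true_type/false_type so the callee can branch with if constexpr and the compiler can drop the unused path.

dispatch_bool(align_K, [&](auto kAlignedK) {
  if constexpr (kAlignedK.value) {
    // K is a multiple of the block size: take the fully aligned path
  } else {
    // otherwise handle the K remainder with the bounds-checked loads
  }
});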
@@ -172,6 +172,165 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
// Regular steel matmul dispatch
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
template <bool CHECK_AB>
|
||||
void steel_matmul_regular_axpby_nax(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape,
|
||||
Strides batch_strides,
|
||||
int64_t A_batch_stride,
|
||||
int64_t B_batch_stride,
|
||||
int64_t matrix_stride_out,
|
||||
int64_t C_batch_stride /* = 0*/,
|
||||
float alpha /* = 1.0f */,
|
||||
float beta /* = 0.0f */) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 128, bn = 128, bk = 512;
|
||||
int wm = 4, wn = 4;
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
|
||||
// clang-format off
|
||||
kname << "steel_gemm_fused_nax_"
|
||||
<< (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(out)
|
||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = (batch_shape.size() > 1);
|
||||
const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
|
||||
const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
/* metal::Device& d = */ d,
|
||||
/* const std::string& kernel_name = */ base_name,
|
||||
/* const std::string& hash_name = */ hash_name,
|
||||
/* const metal::MTLFCList& func_consts = */ func_consts,
|
||||
/* const array& out = */ out,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* int bm = */ bm,
|
||||
/* int bn = */ bn,
|
||||
/* int bk = */ bk,
|
||||
/* int wm = */ wm,
|
||||
/* int wn = */ wn);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = tm <= 3 ? 0 : 1;
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int ldd = */ ldd,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int64_t batch_stride_a = */ A_batch_stride,
|
||||
/* const int64_t batch_stride_b = */ B_batch_stride,
|
||||
/* const int64_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
tn = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
|
||||
if (has_batch) {
|
||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||
}
|
||||
|
||||
if (use_out_source) {
|
||||
int ldc = c.strides()[c.ndim() - 2];
|
||||
int fdc = c.strides()[c.ndim() - 1];
|
||||
|
||||
GEMMAddMMParams params{
|
||||
/* const int ldc = */ ldc,
|
||||
/* const int fdc = */ fdc,
|
||||
/* const int64_t batch_stride_c = */ C_batch_stride,
|
||||
/* const float alpha = */ alpha,
|
||||
/* const float beta = */ beta};
|
||||
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
compute_encoder.set_bytes(params, 5);
|
||||
}
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
// Record copies
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX
|
||||
|
||||
template <bool CHECK_AB>
|
||||
void steel_matmul_regular_axpby(
|
||||
const Stream& s,
|
||||
@@ -198,6 +357,41 @@ void steel_matmul_regular_axpby(
|
||||
int64_t C_batch_stride /* = 0*/,
|
||||
float alpha /* = 1.0f */,
|
||||
float beta /* = 0.0f */) {
|
||||
#ifdef MLX_ENABLE_NAX
|
||||
|
||||
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
|
||||
if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
|
||||
(env::enable_tf32() || a.dtype() != float32)) {
|
||||
return steel_matmul_regular_axpby_nax<CHECK_AB>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ c,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ ldd,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides batch_strides = */ batch_strides,
|
||||
/* int64_t A_batch_stride = */ A_batch_stride,
|
||||
/* int64_t B_batch_stride = */ B_batch_stride,
|
||||
/* int64_t matrix_stride_out = */ matrix_stride_out,
|
||||
/* int64_t C_batch_stride = */ C_batch_stride,
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MLX_ENABLE_NAX

using namespace mlx::steel;

// Determine dispatch kernel
@@ -1572,6 +1766,153 @@ void gather_mm_rhs(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#ifdef MLX_ENABLE_NAX

void gather_mm_rhs_nax(
const array& a_,
const array& b_,
const array& indices_,
array& out,
metal::Device& d,
const Stream& s) {
array indices = ensure_row_contiguous(indices_, d, s);
auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);

// Broadcast a with indices. If we are here that means lhs_indices were not
// provided so the lhs_indices are implied to be the shape of a broadcasted
// with rhs_indices. We need only broadcast a and copy it as if applying the
// lhs_indices.
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
return ensure_row_contiguous(x, d, s);
}

auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
array a = broadcast_with_indices(a_);
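// Added note (not part of the patch): as a concrete example, if a_ has shape
// (1, m, K) and rhs_indices has shape (L, I), the lambda above materializes a
// row-contiguous copy of a_ with shape (L, I, m, K), as if explicit
// lhs_indices had been applied.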

// Extract the matmul shapes
int K = a.shape(-1);
int M = a.size() / K;
int N = b.shape(-1);
int lda = a.strides()[a.ndim() - 2]; // should be K
int E = b.shape(0);

// Define the dispatch blocks
int bm, bn = 128, bk = 128, wm, wn = 4;
if (M / E > 48) {
bm = 64;
wm = 2;
} else if (M / E > 24) {
bm = 32l;
wm = 1;
} else {
bm = 16;
wm = 1;
}
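// Added note (not part of the patch): M / E approximates the average number
// of rows routed to each expert, so larger per-expert batches get a taller
// tile and two simdgroups along M (bm = 64, wm = 2), while small batches drop
// to bm = 16, presumably to avoid launching mostly empty tiles. For example,
// M = 1024 rows over E = 8 experts gives M / E = 128 > 48, hence bm = 64.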

const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;

// Define the kernel name
std::string base_name;
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_nax_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
'_',
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);

metal::MTLFCList func_consts = {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
};

// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
base_name,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
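// Added note (not part of the patch): base_name picks the template
// instantiation (types and tile sizes are baked into the kernel name), while
// hash_name additionally folds in the align_M/N/K function constants so that
// differently specialized pipelines are cached under distinct keys.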

// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_gather_kernel(
d,
base_name,
hash_name,
func_consts,
out,
false,
transpose_b,
bm,
bn,
bk,
wm,
wn,
true);
compute_encoder.set_compute_pipeline_state(kernel);

// Prepare the matmul params
auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
steel::GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ static_cast<int>(ldb),
/* const int ldd = */ N,
/* const int tiles_n = */ (N + bn - 1) / bn,
/* const int tiles_m = */ (M + bm - 1) / bm,
/* const int64_t batch_stride_a = */ 0,
/* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
/* const int64_t batch_stride_d = */ 0,
/* const int swizzle_log = */ 0,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 0};

// Prepare the grid
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);

// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(indices, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(params, 4);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#endif // MLX_ENABLE_NAX

void gather_mv(
const array& mat_,
const array& vec_,
@@ -1855,6 +2196,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// We are walking a in order and b is also in order so we can batch up the
// matmuls and reuse reading a and b.
if (M == 1 && right_sorted_ == true) {
#ifdef MLX_ENABLE_NAX

if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() &&
!issubdtype(a.dtype(), complexfloating) &&
(env::enable_tf32() || a.dtype() != float32)) {
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
}
}

#endif // MLX_ENABLE_NAX

gather_mm_rhs(a, b, rhs_indices, out, d, s);
return;
}

@@ -451,6 +451,210 @@ void qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#ifdef MLX_ENABLE_NAX

void qmm_nax(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
array& out,
bool transpose,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;

int wm = 2;
int wn = 2;
int bm = 64;
int bn = 64;
int bk = 64;
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);
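// Added note (not part of the patch): each threadgroup is 32 x wn x wm
// threads (one 32-lane simdgroup per (wm, wn) position) and produces one
// bm x bn tile of the output, so the grid covers ceil(N / bn) x ceil(M / bm)
// tiles with the batch index in the z dimension.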

std::string kname;
kname.reserve(64);
bool aligned = N % 64 == 0;
bool batched = B > 1;
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
mode + (transpose ? "_qmm_t_nax_" : "_qmm_n_nax_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
batched ? "_batch_1" : "_batch_0");
std::string template_def;
MTL::ComputePipelineState* kernel;
if (transpose) {
kernel = get_quantized_kernel_wrapped(
d,
kname,
"qmm_t_nax",
mode,
type_string,
group_size,
bits,
aligned,
batched);
} else {
kernel = get_quantized_kernel_wrapped(
d, kname, "qmm_n_nax", mode, type_string, group_size, bits, batched);
}
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);

int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void gather_qmm_nax(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
bool transpose,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;

int wm = 2;
int wn = 2;
int bm = 64;
int bn = 64;
int bk = 32;
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);

std::string kname;
kname.reserve(64);
bool aligned = N % 64 == 0;
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
mode + (transpose ? "_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
MTL::ComputePipelineState* kernel;
if (transpose) {
kernel = get_quantized_kernel_wrapped(
d,
kname,
"gather_qmm_t_nax_",
mode,
type_string,
group_size,
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
aligned);
} else {
kernel = get_quantized_kernel_wrapped(
d,
kname,
"gather_qmm_n_nax_",
mode,
type_string,
group_size,
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
}

auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);

int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#endif // MLX_ENABLE_NAX

void qmm(
const array& x,
const array& w,
@@ -466,6 +670,31 @@ void qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX

if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}

#endif // MLX_ENABLE_NAX
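// Added note (not part of the patch): the NAX branch above only fires for
// transposed weights with K a multiple of 64, and for float32 only when TF32
// is enabled; every other case falls through to the existing kernels below.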

int B = out.size() / M / N;

int wm = 2;
@@ -543,6 +772,33 @@ void gather_qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX

if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
(env::enable_tf32() || x.dtype() != float32)) {
return gather_qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* const array& lhs_indices = */ lhs_indices,
/* const array& rhs_indices = */ rhs_indices,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}

#endif // MLX_ENABLE_NAX

int B = out.size() / M / N;

int wm = 2;
@@ -719,6 +975,141 @@ void gather_qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#ifdef MLX_ENABLE_NAX

void gather_qmm_rhs_nax(
const array& x_,
const array& w_,
const array& scales_,
const std::optional<array>& biases_,
const array& indices_,
array& out,
bool transpose,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s,
const std::string mode) {
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);

// Broadcast x with indices. If we are here that means lhs_indices were not
// provided so the lhs_indices are implied to be the shape of x broadcasted
// with rhs_indices. We need only broadcast x and copy it as if applying the
// lhs_indices.
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
return ensure_row_contiguous(x, d, s);
}

auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};

// Normalize the input arrays
array x = broadcast_with_indices(x_);
array w = ensure_row_contiguous(w_, d, s);
array scales = ensure_row_contiguous(scales_, d, s);

// TODO: Tune the block sizes
int bm = 64, bn = 64, bk = 64;
int wm = 2, wn = 2;

const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;

// Make the kernel name
std::string kname;
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
mode +
(transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
"_bm_",
bm,
"_bn_",
bn,
"_bk_",
bk,
"_wm_",
wm,
"_wn_",
wn);

metal::MTLFCList func_consts = {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
};

// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
kname,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');

// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_gather_qmm_kernel(
d,
kname,
hash_name,
func_consts,
x,
group_size,
bits,
mode,
bm,
bn,
bk,
wm,
wn,
transpose);
compute_encoder.set_compute_pipeline_state(kernel);

MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);

int c = 0;
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases_) {
array biases = ensure_row_contiguous(*biases_, d, s);
compute_encoder.set_input_array(biases, c++);
}
compute_encoder.set_input_array(indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(M, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(K, c++);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#endif // MLX_ENABLE_NAX

void gather_qmm_rhs(
const array& x_,
const array& w_,
@@ -735,6 +1126,32 @@ void gather_qmm_rhs(
metal::Device& d,
const Stream& s,
const std::string mode) {
#ifdef MLX_ENABLE_NAX

if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose &&
(env::enable_tf32() || x_.dtype() != float32)) {
return gather_qmm_rhs_nax(
/* const array& x_ = */ x_,
/* const array& w_ = */ w_,
/* const array& scales_ = */ scales_,
/* const std::optional<array>& biases_ = */ biases_,
/* const array& indices_ = */ indices_,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string mode = */ mode);
}
}

#endif // MLX_ENABLE_NAX

// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);


@@ -12,6 +12,146 @@
namespace mlx::core::fast {

namespace {

#ifdef MLX_ENABLE_NAX

void sdpa_full_self_attention_nax(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;

int wm = 4;
int wn = 1;

int bd = q.shape(-1);
int bq = 64;
int bk = 32;

int B = q.shape(0);
int H = q.shape(1);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
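// Added note (not part of the patch): gqa_factor is the number of query heads
// per key/value head (grouped-query attention); it is 1 for plain multi-head
// attention.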

int qL = q.shape(2);
int kL = k.shape(2);

const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};

std::string base_name;
concatenate(
base_name,
"steel_attention_",
type_to_name(q),
"_bq",
bq,
"_bk",
bk,
"_bd",
bd,
"_wm",
wm,
"_wn",
wn,
"_mask",
type_to_name(has_mask ? *mask : q));

std::string hash_name;
concatenate(
hash_name,
base_name,
"_align_Q_",
(align_Q ? 't' : 'n'),
"_align_K_",
(align_K ? 't' : 'n'),
"_has_mask_",
(has_mask ? 't' : 'n'),
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));

auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);

const int NQ = (qL + bq - 1) / bq;
const int NK = (kL + bk - 1) / bk;

const int NQ_aligned = qL / bq;
const int NK_aligned = kL / bk;

AttnParams params{
/* int B = */ B,
/* int H = */ H,
/* int D = */ D,

/* int qL = */ qL,
/* int kL = */ kL,

/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,

/* int NQ = */ NQ,
/* int NK = */ NK,

/* int NQ_aligned = */ NQ_aligned,
/* int NK_aligned = */ NK_aligned,

/* int qL_rem = */ (qL - NQ_aligned * bq),
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
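// Added note (not part of the patch): qL_rem and kL_rem are the sizes of the
// ragged trailing tiles handled by the unaligned path, and qL_off = kL - qL
// presumably shifts the causal mask so the last query row lines up with the
// last key position when the query and key lengths differ.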

compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);

if (has_mask) {
auto& m = *mask;

AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};

compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 7);
}

MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#endif // MLX_ENABLE_NAX

void sdpa_full_self_attention_metal(
const Stream& s,
metal::Device& d,
@@ -23,6 +163,25 @@ void sdpa_full_self_attention_metal(
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& q = */ q,
/* const array& k = */ k,
/* const array& v = */ v,
/* const float scale = */ scale,
/* array& o = */ o,
/* bool do_causal_ = */ do_causal_,
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
}
#endif // MLX_ENABLE_NAX

using namespace mlx::steel;

int wm = 4;

@@ -163,6 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
tests = product(
[128, 64, 32],  # group_size
[2, 4, 8],  # bits
@@ -178,8 +179,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
bits=bits,
transposed=transposed,
):
x = mx.random.normal(shape=(M, K), key=k1)
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
x = mx.random.normal(shape=(M, K), key=k1) / K**0.5
w = (
mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
/ K**0.5
)
x = x.astype(dtype)
w = w.astype(dtype)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
@@ -187,7 +193,9 @@ class TestQuantized(mlx_tests.MLXTestCase):
)
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

tol = 1e-3 if dtype == mx.float32 else 1.5e-3
self.assertLess((y_q - y_hat).abs().max(), tol)
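# Added note (not part of the patch): the looser 1.5e-3 tolerance for half
# precision presumably accommodates the accumulation order of the new NAX
# kernels; the inputs are scaled by 1 / sqrt(K) above so products stay O(1).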

def test_qmm_vjp(self):
key = mx.random.key(0)
@@ -833,48 +841,75 @@ class TestQuantized(mlx_tests.MLXTestCase):
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]

key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3)
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32

for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
else:
group_size = 64
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
if mode == "mxfp4":
group_size = 32
dtype = (
mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
)
else:
group_size = 64
dtype = (
mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
)

indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)

y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(
mx.uint32
)
x = mx.random.normal(xshape, key=k2) / K**0.5
w = mx.random.normal(wshape, key=k3) / K**0.5

y4 = mx.gather_qmm(
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)
x = x.astype(dtype)
w = w.astype(dtype)

self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
w, *wq = quantize(
w, group_size=group_size, mode=mode, transpose=transpose
)

y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)

y4 = mx.gather_qmm(
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)

tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4

self.assertLess((y1 - y2).abs().max(), tol)
self.assertLess((y1 - y3).abs().max(), tol)
self.assertLess((y1 - y4).abs().max(), tol)

self.assertTrue(mx.allclose(y1, y2, atol=tol))
self.assertTrue(mx.allclose(y1, y3, atol=tol))
self.assertTrue(mx.allclose(y1, y4, atol=tol))

def test_gather_qmm_grad(self):
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
@@ -898,10 +933,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
sorted_indices=sort,
)

x = mx.random.normal((16, 1, 256))
w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
cotan = mx.random.normal((16, 1, 256))
key = mx.random.key(0)
k1, k2, k3, k4 = mx.random.split(key, 4)
dtype = mx.float32

x = mx.random.normal((16, 1, 256), key=k1).astype(dtype)
w, s, b = mx.quantize(mx.random.normal((4, 256, 256), key=k2).astype(dtype))
indices = mx.sort(mx.random.randint(0, 4, shape=(16,), key=k3))
cotan = mx.random.normal((16, 1, 256), key=k4).astype(dtype)

(o1,), (dx1, ds1, db1) = mx.vjp(
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
@@ -914,6 +953,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
[cotan],
)

self.assertLess((o1 - o2).abs().max(), 1e-4)
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))