Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
[WIP] Add init QMMs (some CI tests failing)
@@ -269,7 +269,6 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
 
 inline bool is_nax_available() {
   static bool is_nax_available_ =
-      /* __builtin_available(macOS 26.2, *) && */
       metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
   return is_nax_available_;
 }
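Every NAX call site in this commit layers three gates: the MLX_ENABLE_NAX compile-time flag, a runtime OS-availability check, and the cached hardware-generation check above. A minimal sketch of the combined pattern, assembled from the hunks below (not a complete function from the diff):

#ifdef MLX_ENABLE_NAX
  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
    if (metal::is_nax_available()) {
      // ... dispatch the NAX kernel and return ...
    }
  }
#endif // MLX_ENABLE_NAX
  // ... otherwise fall through to the pre-existing non-NAX path ...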
@@ -137,6 +137,8 @@ if(MLX_ENABLE_NAX)
   build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
 
+  build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
+
   set(STEEL_NAX_ATTN_HEADERS
       steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
       steel/utils/integral_constant.h)
mlx/backend/metal/kernels/quantized_nax.h (new file, 1559 lines): file diff suppressed because it is too large.

mlx/backend/metal/kernels/quantized_nax.metal (new file, 105 lines):
@@ -0,0 +1,105 @@
+// Copyright © 2023-2024 Apple Inc.
+
+// clang-format off
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
+#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
+#include "mlx/backend/metal/kernels/quantized_nax.h"
+
+#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits, \
+      name, \
+      type, \
+      group_size, \
+      bits, bm, bk, bn, wm, wn)
+
+#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      batched, bm, bk, bn, wm, wn)
+
+#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      aligned, bm, bk, bn, wm, wn)
+
+#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      aligned, \
+      batched, bm, bk, bn, wm, wn)
+
+#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
+      func, \
+      type, \
+      group_size, \
+      bits, \
+      bm, \
+      bn, \
+      bk, \
+      wm, \
+      wn, \
+      transpose)
+
+#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
+  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
+  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)
+
+#define instantiate_quantized_all_batched(type, group_size, bits) \
+  instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)
+
+
+#define instantiate_quantized_all_single(type, group_size, bits) \
+  instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)
+
+#define instantiate_quantized_all_aligned(type, group_size, bits) \
+  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
+  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
+  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
+  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
+  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
+  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)
+
+#define instantiate_quantized_all_rhs(type, group_size, bits) \
+  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
+  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)
+
+#define instantiate_quantized_funcs(type, group_size, bits) \
+  instantiate_quantized_all_batched(type, group_size, bits) \
+  instantiate_quantized_all_aligned(type, group_size, bits) \
+  instantiate_quantized_all_rhs(type, group_size, bits)
+
+#define instantiate_quantized_types(group_size, bits) \
+  instantiate_quantized_funcs(float, group_size, bits) \
+  instantiate_quantized_funcs(float16_t, group_size, bits) \
+  instantiate_quantized_funcs(bfloat16_t, group_size, bits)
+
+#define instantiate_quantized_groups(bits) \
+  instantiate_quantized_types(128, bits) \
+  instantiate_quantized_types(64, bits)
+
+#define instantiate_quantized_all() \
+  instantiate_quantized_groups(2) \
+  instantiate_quantized_groups(3) \
+  instantiate_quantized_groups(4) \
+  instantiate_quantized_groups(5) \
+  instantiate_quantized_groups(6) \
+  instantiate_quantized_groups(8)
+
+instantiate_quantized_all() // clang-format on
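A worked example of the naming scheme (the parameter values are illustrative, not from the diff): expanding

  instantiate_quantized_aligned_batched(
      affine_qmm_t_nax, float16_t, 64, 4, 64, 64, 64, 2, 2, true, 1)

stamps out a kernel named

  affine_qmm_t_nax_float16_t_gs_64_b_4_bm64_bn64_bk64_wm2_wn2_alN_true_batch_1

which is byte-for-byte the string the host-side qmm_nax() below builds with concatenate() before looking up the pipeline.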
@@ -178,7 +178,7 @@ template <
   }
 
   const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
-  const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
+  // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
   const bool is_last_q = is_last_bq;
 
   const short lim_rows_q = params->qL_rem - (tm + sm);
@@ -39,9 +39,6 @@ struct BaseNAXFrag {
       kElemRows * kElemCols == kElemsPerFrag,
       "MMAFrag shape is not consistent with MMAFrag size");
 
-  template <typename U>
-  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
-
   template <typename U>
   using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
 
@@ -40,9 +40,6 @@ struct BaseNAXFrag {
       kElemRows * kElemCols == kElemsPerFrag,
       "MMAFrag shape is not consistent with MMAFrag size");
 
-  template <typename U>
-  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
-
   template <typename U>
   using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
 
@@ -359,33 +359,35 @@ void steel_matmul_regular_axpby(
     float beta /* = 0.0f */) {
 #ifdef MLX_ENABLE_NAX
 
-  if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
-      (a.dtype() != float32 || env::enable_tf32())) {
-    return steel_matmul_regular_axpby_nax<CHECK_AB>(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& a = */ a,
-        /* const array& b = */ b,
-        /* const array& c = */ c,
-        /* array& out = */ out,
-        /* int M = */ M,
-        /* int N = */ N,
-        /* int K = */ K,
-        /* int batch_size_out = */ batch_size_out,
-        /* int lda = */ lda,
-        /* int ldb = */ ldb,
-        /* int ldd = */ ldd,
-        /* bool transpose_a = */ transpose_a,
-        /* bool transpose_b = */ transpose_b,
-        /* std::vector<array>& copies = */ copies,
-        /* Shape batch_shape = */ batch_shape,
-        /* Strides batch_strides = */ batch_strides,
-        /* int64_t A_batch_stride = */ A_batch_stride,
-        /* int64_t B_batch_stride = */ B_batch_stride,
-        /* int64_t matrix_stride_out = */ matrix_stride_out,
-        /* int64_t C_batch_stride = */ C_batch_stride,
-        /* float alpha = */ alpha,
-        /* float beta = */ beta);
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
+        (a.dtype() != float32 || env::enable_tf32())) {
+      return steel_matmul_regular_axpby_nax<CHECK_AB>(
+          /* const Stream& s = */ s,
+          /* metal::Device& d = */ d,
+          /* const array& a = */ a,
+          /* const array& b = */ b,
+          /* const array& c = */ c,
+          /* array& out = */ out,
+          /* int M = */ M,
+          /* int N = */ N,
+          /* int K = */ K,
+          /* int batch_size_out = */ batch_size_out,
+          /* int lda = */ lda,
+          /* int ldb = */ ldb,
+          /* int ldd = */ ldd,
+          /* bool transpose_a = */ transpose_a,
+          /* bool transpose_b = */ transpose_b,
+          /* std::vector<array>& copies = */ copies,
+          /* Shape batch_shape = */ batch_shape,
+          /* Strides batch_strides = */ batch_strides,
+          /* int64_t A_batch_stride = */ A_batch_stride,
+          /* int64_t B_batch_stride = */ B_batch_stride,
+          /* int64_t matrix_stride_out = */ matrix_stride_out,
+          /* int64_t C_batch_stride = */ C_batch_stride,
+          /* float alpha = */ alpha,
+          /* float beta = */ beta);
+    }
   }
 
 #endif // MLX_ENABLE_NAX
@@ -2196,8 +2198,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (M == 1 && right_sorted_ == true) {
 #ifdef MLX_ENABLE_NAX
 
-    if (metal::is_nax_available() && a.dtype() != float32) {
-      return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
+    if (__builtin_available(
+            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+      if (metal::is_nax_available() && a.dtype() != float32) {
+        return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
+      }
     }
 
 #endif // MLX_ENABLE_NAX
@@ -451,6 +451,96 @@ void qvm(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+#ifdef MLX_ENABLE_NAX
+
+void qmm_nax(
+    const array& x,
+    const array& w,
+    const array& scales,
+    const std::optional<array>& biases,
+    array& out,
+    bool transpose,
+    int group_size,
+    int bits,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s,
+    const std::string& mode) {
+  int B = out.size() / M / N;
+
+  int wm = 2;
+  int wn = 2;
+  int bm = 64;
+  int bn = 64;
+  int bk = 64;
+  MTL::Size group_dims(32, wn, wm);
+  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);
+
+  std::string kname;
+  kname.reserve(64);
+  bool aligned = N % 64 == 0;
+  bool batched = B > 1;
+  std::string type_string = get_type_string(x.dtype());
+  concatenate(
+      kname,
+      mode + (transpose ? "_qmm_t_nax_" : "_qmm_n_nax_"),
+      type_string,
+      "_gs_",
+      group_size,
+      "_b_",
+      bits,
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
+      batched ? "_batch_1" : "_batch_0");
+  std::string template_def;
+  MTL::ComputePipelineState* kernel;
+  if (transpose) {
+    kernel = get_quantized_kernel_wrapped(
+        d,
+        kname,
+        "qmm_t_nax",
+        mode,
+        type_string,
+        group_size,
+        bits,
+        aligned,
+        batched);
+  } else {
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "qmm_n_nax", mode, type_string, group_size, bits, batched);
+  }
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  int c = 0;
+  compute_encoder.set_input_array(w, c++);
+  compute_encoder.set_input_array(scales, c++);
+  if (biases) {
+    compute_encoder.set_input_array(*biases, c++);
+  }
+  compute_encoder.set_input_array(x, c++);
+  compute_encoder.set_output_array(out, c++);
+  compute_encoder.set_bytes(K, c++);
+  compute_encoder.set_bytes(N, c++);
+  compute_encoder.set_bytes(M, c++);
+  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
+
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+#endif // MLX_ENABLE_NAX
+
 void qmm(
     const array& x,
     const array& w,
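The dispatch geometry above assigns one threadgroup per 64x64 output tile, B tiles deep in the batch dimension, with a (32, wn, wm) threadgroup, i.e. one simdgroup of 32 lanes per (wm, wn) warp-tile. A standalone sanity check with made-up shapes (M, N, B are hypothetical; bm/bn/wm/wn mirror the defaults hard-coded in qmm_nax):

#include <cstdio>

int main() {
  int M = 640, N = 4096, B = 2;
  int bm = 64, bn = 64, wm = 2, wn = 2;
  // Same ceil-division as MTL::Size grid_dims above.
  int gx = (N + bn - 1) / bn; // 64 tile columns
  int gy = (M + bm - 1) / bm; // 10 tile rows
  std::printf("grid (%d, %d, %d), group (32, %d, %d)\n", gx, gy, B, wn, wm);
  return 0;
}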
@@ -466,6 +556,31 @@ void qmm(
     metal::Device& d,
     const Stream& s,
     const std::string& mode) {
+#ifdef MLX_ENABLE_NAX
+
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && mode == "affine" && (group_size >= 64) &&
+        transpose && (M % 64 == 0) && (N % 64 == 0) && (K % 64 == 0)) {
+      return qmm_nax(
+          /* const array& x = */ x,
+          /* const array& w = */ w,
+          /* const array& scales = */ scales,
+          /* const std::optional<array>& biases = */ biases,
+          /* array& out = */ out,
+          /* bool transpose = */ transpose,
+          /* int group_size = */ group_size,
+          /* int bits = */ bits,
+          /* int M = */ M,
+          /* int N = */ N,
+          /* int K = */ K,
+          /* metal::Device& d = */ d,
+          /* const Stream& s = */ s,
+          /* const std::string& mode = */ mode);
+    }
+  }
+
+#endif // MLX_ENABLE_NAX
+
   int B = out.size() / M / N;
 
   int wm = 2;
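The fast path here is deliberately narrow. A restatement of the predicate as a standalone helper (takes_nax_qmm_path is a name invented for this sketch, not a function in the diff):

#include <string>

bool takes_nax_qmm_path(bool nax_available, const std::string& mode,
                        int group_size, bool transpose, int M, int N, int K) {
  // Mirrors the condition above: affine mode only, group_size >= 64,
  // transposed weights, and M, N, K all multiples of the 64-wide tiles.
  return nax_available && mode == "affine" && group_size >= 64 && transpose &&
      M % 64 == 0 && N % 64 == 0 && K % 64 == 0;
}

Anything else, for example a ragged K or mode == "mxfp4", falls through to the existing qmm kernels below.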
@@ -719,6 +834,141 @@ void gather_qvm(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+#ifdef MLX_ENABLE_NAX
+
+void gather_qmm_rhs_nax(
+    const array& x_,
+    const array& w_,
+    const array& scales_,
+    const std::optional<array>& biases_,
+    const array& indices_,
+    array& out,
+    bool transpose,
+    int group_size,
+    int bits,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s,
+    const std::string mode) {
+  // Start by normalizing the indices
+  array indices = ensure_row_contiguous(indices_, d, s);
+
+  // Broadcast x with indices. If we are here that means lhs_indices were not
+  // provided so the lhs_indices are implied to be the shape of x broadcasted
+  // with rhs_indices. We need only broadcast x and copy it as if applying the
+  // lhs_indices.
+  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
+    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
+      return ensure_row_contiguous(x, d, s);
+    }
+
+    auto x_shape = indices.shape();
+    x_shape.push_back(x.shape(-2));
+    x_shape.push_back(x.shape(-1));
+    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
+    broadcast(x, new_x);
+    return ensure_row_contiguous(new_x, d, s);
+  };
+
+  // Normalize the input arrays
+  array x = broadcast_with_indices(x_);
+  array w = ensure_row_contiguous(w_, d, s);
+  array scales = ensure_row_contiguous(scales_, d, s);
+
+  // TODO: Tune the block sizes
+  int bm = 64, bn = 64, bk = 64;
+  int wm = 2, wn = 2;
+
+  const bool align_M = (M % bm) == 0;
+  const bool align_N = (N % bn) == 0;
+  const bool align_K = (K % bk) == 0;
+
+  // Make the kernel name
+  std::string kname;
+  kname.reserve(64);
+  std::string type_string = get_type_string(x.dtype());
+  concatenate(
+      kname,
+      mode +
+          (transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"),
+      type_string,
+      "_gs_",
+      group_size,
+      "_b_",
+      bits,
+      "_bm_",
+      bm,
+      "_bn_",
+      bn,
+      "_bk_",
+      bk,
+      "_wm_",
+      wm,
+      "_wn_",
+      wn);
+
+  metal::MTLFCList func_consts = {
+      {&align_M, MTL::DataType::DataTypeBool, 200},
+      {&align_N, MTL::DataType::DataTypeBool, 201},
+      {&align_K, MTL::DataType::DataTypeBool, 202},
+  };
+
+  // And the kernel hash that includes the function constants
+  std::string hash_name;
+  hash_name.reserve(128);
+  concatenate(
+      hash_name,
+      kname,
+      "_align_M_",
+      align_M ? 't' : 'n',
+      "_align_N_",
+      align_N ? 't' : 'n',
+      "_align_K_",
+      align_K ? 't' : 'n');
+
+  // Get and set the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = get_gather_qmm_kernel(
+      d,
+      kname,
+      hash_name,
+      func_consts,
+      x,
+      group_size,
+      bits,
+      mode,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn,
+      transpose);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  MTL::Size group_dims(32, wn, wm);
+  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
+
+  int c = 0;
+  compute_encoder.set_input_array(x, c++);
+  compute_encoder.set_input_array(w, c++);
+  compute_encoder.set_input_array(scales, c++);
+  if (biases_) {
+    array biases = ensure_row_contiguous(*biases_, d, s);
+    compute_encoder.set_input_array(biases, c++);
+  }
+  compute_encoder.set_input_array(indices, c++);
+  compute_encoder.set_output_array(out, c++);
+  compute_encoder.set_bytes(M, c++);
+  compute_encoder.set_bytes(N, c++);
+  compute_encoder.set_bytes(K, c++);
+
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+#endif // MLX_ENABLE_NAX
+
 void gather_qmm_rhs(
     const array& x_,
     const array& w_,
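The broadcast_with_indices lambda above decides whether x already carries one (M, K) matrix per gather index. The shape arithmetic, with hypothetical sizes standing in for the arrays (nothing here is from the diff except the formula):

#include <cassert>

int main() {
  // x shaped (2, 1, 16, 64): x.size() / x.shape(-2) / x.shape(-1) == 2.
  long x_size = 2L * 1 * 16 * 64;
  long M = 16, K = 64;
  // indices shaped (2, 4): indices.size() == 8.
  long indices_size = 2 * 4;
  // Fewer matrices than indices, so the host broadcasts x to
  // (2, 4, 16, 64) before the row-contiguous copy.
  assert(x_size / M / K != indices_size);
  return 0;
}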
@@ -735,6 +985,31 @@ void gather_qmm_rhs(
     metal::Device& d,
     const Stream& s,
     const std::string mode) {
+#ifdef MLX_ENABLE_NAX
+
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && mode == "affine" && (group_size >= 64)) {
+      return gather_qmm_rhs_nax(
+          /* const array& x_ = */ x_,
+          /* const array& w_ = */ w_,
+          /* const array& scales_ = */ scales_,
+          /* const std::optional<array>& biases_ = */ biases_,
+          /* const array& indices_ = */ indices_,
+          /* array& out = */ out,
+          /* bool transpose = */ transpose,
+          /* int group_size = */ group_size,
+          /* int bits = */ bits,
+          /* int M = */ M,
+          /* int N = */ N,
+          /* int K = */ K,
+          /* metal::Device& d = */ d,
+          /* const Stream& s = */ s,
+          /* const std::string mode = */ mode);
+    }
+  }
+
+#endif // MLX_ENABLE_NAX
+
   // Start by normalizing the indices
   array indices = ensure_row_contiguous(indices_, d, s);
 
@@ -164,19 +164,21 @@ void sdpa_full_self_attention_metal(
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
 #ifdef MLX_ENABLE_NAX
-  if (metal::is_nax_available() && q.shape(3) != 80 &&
-      (q.dtype() != float32 || env::enable_tf32())) {
-    return sdpa_full_self_attention_nax(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& q = */ q,
-        /* const array& k = */ k,
-        /* const array& v = */ v,
-        /* const float scale = */ scale,
-        /* array& o = */ o,
-        /* bool do_causal_ = */ do_causal_,
-        /* const std::optional<array>& mask = */ mask,
-        /* const std::optional<array>& sinks = */ sinks);
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && q.shape(3) != 80 &&
+        (q.dtype() != float32 || env::enable_tf32())) {
+      return sdpa_full_self_attention_nax(
+          /* const Stream& s = */ s,
+          /* metal::Device& d = */ d,
+          /* const array& q = */ q,
+          /* const array& k = */ k,
+          /* const array& v = */ v,
+          /* const float scale = */ scale,
+          /* array& o = */ o,
+          /* bool do_causal_ = */ do_causal_,
+          /* const std::optional<array>& mask = */ mask,
+          /* const std::optional<array>& sinks = */ sinks);
+    }
   }
 #endif // MLX_ENABLE_NAX
 
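The SDPA gate adds two exclusions beyond the usual layering: head dimension 80 never takes the NAX path, and float32 only qualifies when TF32 is enabled. As a predicate (takes_nax_sdpa_path is a name invented for this sketch; env::enable_tf32() is the switch referenced in the diff):

bool takes_nax_sdpa_path(bool nax_available, int head_dim, bool is_float32,
                         bool tf32_enabled) {
  return nax_available && head_dim != 80 && (!is_float32 || tf32_enabled);
}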
@@ -834,47 +834,54 @@ class TestQuantized(mlx_tests.MLXTestCase):
             (64, 512, 512, 4, 2, False, "affine"),
         ]
         for L, K, D, E, I, transpose, mode in parameters:
-            if mode == "mxfp4":
-                group_size = 32
-            else:
-                group_size = 64
-            K, D = (K, D) if transpose else (D, K)
-            ishape = (L, I)
-            xshape = (L, 1, 1, K)
-            wshape = (E, D, K) if transpose else (E, K, D)
-
-            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
-            x = mx.random.normal(xshape) / K**0.5
-            w = mx.random.normal(wshape) / K**0.5
-            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
-            y1 = mx.gather_mm(x, w, rhs_indices=indices)
-            y2 = mx.gather_qmm(
-                x,
-                *wq,
-                group_size=group_size,
-                mode=mode,
-                transpose=transpose,
-                rhs_indices=indices
-            )
-            xs, idx, inv_order = gather_sort(x, indices)
-            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
-            y4 = mx.gather_qmm(
-                xs,
-                *wq,
-                group_size=group_size,
-                mode=mode,
-                rhs_indices=idx,
-                transpose=transpose,
-                sorted_indices=True
-            )
-            y3 = scatter_unsort(y3, inv_order, indices.shape)
-            y4 = scatter_unsort(y4, inv_order, indices.shape)
-
-            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
-            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
-            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
+            with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
+                if mode == "mxfp4":
+                    group_size = 32
+                else:
+                    group_size = 64
+                K, D = (K, D) if transpose else (D, K)
+                ishape = (L, I)
+                xshape = (L, 1, 1, K)
+                wshape = (E, D, K) if transpose else (E, K, D)
+
+                indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
+                x = mx.random.normal(xshape) / K**0.5
+                w = mx.random.normal(wshape) / K**0.5
+                w, *wq = quantize(
+                    w, group_size=group_size, mode=mode, transpose=transpose
+                )
+
+                y1 = mx.gather_mm(x, w, rhs_indices=indices)
+                y2 = mx.gather_qmm(
+                    x,
+                    *wq,
+                    group_size=group_size,
+                    mode=mode,
+                    transpose=transpose,
+                    rhs_indices=indices
+                )
+                xs, idx, inv_order = gather_sort(x, indices)
+                y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
+                y4 = mx.gather_qmm(
+                    xs,
+                    *wq,
+                    group_size=group_size,
+                    mode=mode,
+                    rhs_indices=idx,
+                    transpose=transpose,
+                    sorted_indices=True
+                )
+                y3 = scatter_unsort(y3, inv_order, indices.shape)
+                y4 = scatter_unsort(y4, inv_order, indices.shape)
+
+                self.assertLess((y1 - y2).abs().max(), 1e-5)
+                self.assertLess((y1 - y3).abs().max(), 1e-5)
+                self.assertLess((y1 - y4).abs().max(), 2e-4)
+
+                self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
+                self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
+                self.assertTrue(mx.allclose(y1, y4, atol=2e-4))
 
     def test_gather_qmm_grad(self):
         def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):