[WIP] Add init QMMs (some CI tests failing)

Author: Jagrit Digani
Date: 2025-11-18 12:12:59 -08:00
parent a0558cd3af
commit c532eb94c1
11 changed files with 2035 additions and 87 deletions

View File

@@ -269,7 +269,6 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
 inline bool is_nax_available() {
   static bool is_nax_available_ =
-      /* __builtin_available(macOS 26.2, *) && */
       metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
   return is_nax_available_;
 }

View File

@@ -137,6 +137,8 @@ if(MLX_ENABLE_NAX)
   build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
   build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
+  build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
   set(STEEL_NAX_ATTN_HEADERS
       steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
       steel/utils/integral_constant.h)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,105 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/quantized_nax.h"
#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits, bm, bk, bn, wm, wn)
#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched, bm, bk, bn, wm, wn)
#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned, bm, bk, bn, wm, wn)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched, bm, bk, bn, wm, wn)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(5) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on
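As written, instantiate_quantized_funcs pulls in the batched, aligned, and rhs variants (instantiate_quantized_all_single is defined but not referenced), so each (type, group_size, bits) combination yields 10 kernels and the 3 types x 2 group sizes x 6 bit widths expand to roughly 360 instantiations, all with the single tile configuration bm = bn = bk = 64, wm = wn = 2. The host side has to reproduce these names exactly; the following standalone sketch (not part of the commit, the helper name is hypothetical) rebuilds the name pattern emitted by instantiate_quantized_aligned_batched for one configuration:

// Hypothetical standalone sketch: rebuild the kernel-name pattern emitted by
// instantiate_quantized_aligned_batched above for one configuration.
#include <iostream>
#include <string>

std::string qmm_t_nax_name(
    const std::string& mode, // "affine"
    const std::string& type, // "float", "float16_t" or "bfloat16_t"
    int group_size,
    int bits,
    bool aligned,
    bool batched) {
  // Only one tile configuration is instantiated: bm = bn = bk = 64, wm = wn = 2.
  std::string name = mode + "_qmm_t_nax_" + type + "_gs_" +
      std::to_string(group_size) + "_b_" + std::to_string(bits) +
      "_bm64_bn64_bk64_wm2_wn2";
  name += aligned ? "_alN_true" : "_alN_false";
  name += batched ? "_batch_1" : "_batch_0";
  return name;
}

int main() {
  // Prints:
  // affine_qmm_t_nax_float16_t_gs_64_b_4_bm64_bn64_bk64_wm2_wn2_alN_true_batch_1
  std::cout << qmm_t_nax_name("affine", "float16_t", 64, 4, true, true) << "\n";
  return 0;
}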

View File

@@ -178,7 +178,7 @@ template <
   }
   const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
-  const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
+  // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
   const bool is_last_q = is_last_bq;
   const short lim_rows_q = params->qL_rem - (tm + sm);

View File

@@ -39,9 +39,6 @@ struct BaseNAXFrag {
       kElemRows * kElemCols == kElemsPerFrag,
       "MMAFrag shape is not consistent with MMAFrag size");
-  template <typename U>
-  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
-
   template <typename U>
   using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

View File

@@ -40,9 +40,6 @@ struct BaseNAXFrag {
       kElemRows * kElemCols == kElemsPerFrag,
       "MMAFrag shape is not consistent with MMAFrag size");
-  template <typename U>
-  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
-
   template <typename U>
   using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

View File

@@ -359,33 +359,35 @@ void steel_matmul_regular_axpby(
     float beta /* = 0.0f */) {
 #ifdef MLX_ENABLE_NAX
-  if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
-      (a.dtype() != float32 || env::enable_tf32())) {
-    return steel_matmul_regular_axpby_nax<CHECK_AB>(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& a = */ a,
-        /* const array& b = */ b,
-        /* const array& c = */ c,
-        /* array& out = */ out,
-        /* int M = */ M,
-        /* int N = */ N,
-        /* int K = */ K,
-        /* int batch_size_out = */ batch_size_out,
-        /* int lda = */ lda,
-        /* int ldb = */ ldb,
-        /* int ldd = */ ldd,
-        /* bool transpose_a = */ transpose_a,
-        /* bool transpose_b = */ transpose_b,
-        /* std::vector<array>& copies = */ copies,
-        /* Shape batch_shape = */ batch_shape,
-        /* Strides batch_strides = */ batch_strides,
-        /* int64_t A_batch_stride = */ A_batch_stride,
-        /* int64_t B_batch_stride = */ B_batch_stride,
-        /* int64_t matrix_stride_out = */ matrix_stride_out,
-        /* int64_t C_batch_stride = */ C_batch_stride,
-        /* float alpha = */ alpha,
-        /* float beta = */ beta);
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
+        (a.dtype() != float32 || env::enable_tf32())) {
+      return steel_matmul_regular_axpby_nax<CHECK_AB>(
+          /* const Stream& s = */ s,
+          /* metal::Device& d = */ d,
+          /* const array& a = */ a,
+          /* const array& b = */ b,
+          /* const array& c = */ c,
+          /* array& out = */ out,
+          /* int M = */ M,
+          /* int N = */ N,
+          /* int K = */ K,
+          /* int batch_size_out = */ batch_size_out,
+          /* int lda = */ lda,
+          /* int ldb = */ ldb,
+          /* int ldd = */ ldd,
+          /* bool transpose_a = */ transpose_a,
+          /* bool transpose_b = */ transpose_b,
+          /* std::vector<array>& copies = */ copies,
+          /* Shape batch_shape = */ batch_shape,
+          /* Strides batch_strides = */ batch_strides,
+          /* int64_t A_batch_stride = */ A_batch_stride,
+          /* int64_t B_batch_stride = */ B_batch_stride,
+          /* int64_t matrix_stride_out = */ matrix_stride_out,
+          /* int64_t C_batch_stride = */ C_batch_stride,
+          /* float alpha = */ alpha,
+          /* float beta = */ beta);
+    }
   }
 #endif // MLX_ENABLE_NAX
@@ -2196,8 +2198,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (M == 1 && right_sorted_ == true) {
 #ifdef MLX_ENABLE_NAX
-    if (metal::is_nax_available() && a.dtype() != float32) {
-      return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
+    if (__builtin_available(
+            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+      if (metal::is_nax_available() && a.dtype() != float32) {
+        return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
+      }
     }
 #endif // MLX_ENABLE_NAX
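Both hunks in this file apply the same three-layer gating that the rest of the commit uses: the compile-time MLX_ENABLE_NAX flag, the OS availability check (the __builtin_available call that was previously only a comment inside is_nax_available()), and finally the runtime device probe. A condensed sketch of that pattern, with hypothetical stand-in predicates for the real checks:

// Condensed sketch of the NAX gating pattern used at each dispatch site.
// os_supports_nax() and device_supports_nax() are hypothetical stand-ins for
// __builtin_available(macOS 26.2, ...) and metal::is_nax_available().
#include <cstdio>

#define MLX_ENABLE_NAX 1 // assume the CMake option is on for this sketch

bool os_supports_nax() { return true; }
bool device_supports_nax() { return true; }

bool use_nax_path(bool dtype_ok) {
#ifdef MLX_ENABLE_NAX
  if (os_supports_nax() && device_supports_nax() && dtype_ok) {
    return true; // dispatch the *_nax kernel
  }
#endif
  return false; // fall back to the existing Steel kernels
}

int main() {
  std::printf("use NAX: %d\n", use_nax_path(/* dtype_ok = */ true));
  return 0;
}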

View File

@@ -451,6 +451,96 @@ void qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void qmm_nax(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
array& out,
bool transpose,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
int wn = 2;
int bm = 64;
int bn = 64;
int bk = 64;
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);
std::string kname;
kname.reserve(64);
bool aligned = N % 64 == 0;
bool batched = B > 1;
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
mode + (transpose ? "_qmm_t_nax_" : "_qmm_n_nax_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
batched ? "_batch_1" : "_batch_0");
std::string template_def;
MTL::ComputePipelineState* kernel;
if (transpose) {
kernel = get_quantized_kernel_wrapped(
d,
kname,
"qmm_t_nax",
mode,
type_string,
group_size,
bits,
aligned,
batched);
} else {
kernel = get_quantized_kernel_wrapped(
d, kname, "qmm_n_nax", mode, type_string, group_size, bits, batched);
}
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void qmm(
const array& x,
const array& w,
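For reference, the launch geometry that qmm_nax() computes above is simple enough to work out by hand; a self-contained sketch of the arithmetic (illustration only, not MLX code) for one example problem size:

// Illustration of the qmm_nax() launch geometry: wm * wn simdgroups of 32
// threads per threadgroup, one bm x bn output tile per threadgroup, and one
// grid slice per batch entry.
#include <cstdio>

int main() {
  const int bm = 64, bn = 64, wm = 2, wn = 2;
  const int M = 512, N = 512, B = 1; // example problem size

  int threads_per_tg = 32 * wn * wm; // 128
  int grid_x = (N + bn - 1) / bn;    // 8 column tiles
  int grid_y = (M + bm - 1) / bm;    // 8 row tiles

  std::printf(
      "grid = (%d, %d, %d), %d threads per threadgroup\n",
      grid_x, grid_y, B, threads_per_tg);
  return 0;
}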
@@ -466,6 +556,31 @@ void qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && mode == "affine" && (group_size >= 64) &&
transpose && (M % 64 == 0) && (N % 64 == 0) && (K % 64 == 0)) {
return qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
int B = out.size() / M / N;
int wm = 2;
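The shape conditions on the fast path above are easy to misread; a small standalone sketch (assuming the compile-time and availability checks have already passed) of which problems qmm() will actually hand to qmm_nax():

// Hypothetical helper mirroring the qmm() fast-path condition above, with the
// MLX_ENABLE_NAX / __builtin_available / is_nax_available checks assumed to
// have passed already.
#include <cstdio>
#include <string>

bool qmm_takes_nax_path(
    const std::string& mode, bool transpose, int group_size, int M, int N, int K) {
  return mode == "affine" && transpose && group_size >= 64 &&
      (M % 64 == 0) && (N % 64 == 0) && (K % 64 == 0);
}

int main() {
  std::printf("%d\n", qmm_takes_nax_path("affine", true, 64, 128, 1024, 512)); // 1
  std::printf("%d\n", qmm_takes_nax_path("affine", true, 64, 100, 1024, 512)); // 0: M not a multiple of 64
  std::printf("%d\n", qmm_takes_nax_path("affine", false, 64, 128, 1024, 512)); // 0: non-transposed case not routed
  return 0;
}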
@@ -719,6 +834,141 @@ void gather_qvm(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#ifdef MLX_ENABLE_NAX
void gather_qmm_rhs_nax(
const array& x_,
const array& w_,
const array& scales_,
const std::optional<array>& biases_,
const array& indices_,
array& out,
bool transpose,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s,
const std::string mode) {
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);
// Broadcast x with indices. If we are here that means lhs_indices were not
// provided so the lhs_indices are implied to be the shape of x broadcasted
// with rhs_indices. We need only broadcast x and copy it as if applying the
// lhs_indices.
auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
return ensure_row_contiguous(x, d, s);
}
auto x_shape = indices.shape();
x_shape.push_back(x.shape(-2));
x_shape.push_back(x.shape(-1));
array new_x(std::move(x_shape), x.dtype(), nullptr, {});
broadcast(x, new_x);
return ensure_row_contiguous(new_x, d, s);
};
// Normalize the input arrays
array x = broadcast_with_indices(x_);
array w = ensure_row_contiguous(w_, d, s);
array scales = ensure_row_contiguous(scales_, d, s);
// TODO: Tune the block sizes
int bm = 64, bn = 64, bk = 64;
int wm = 2, wn = 2;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
// Make the kernel name
std::string kname;
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
mode +
(transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
"_bm_",
bm,
"_bn_",
bn,
"_bk_",
bk,
"_wm_",
wm,
"_wn_",
wn);
metal::MTLFCList func_consts = {
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
};
// And the kernel hash that includes the function constants
std::string hash_name;
hash_name.reserve(128);
concatenate(
hash_name,
kname,
"_align_M_",
align_M ? 't' : 'n',
"_align_N_",
align_N ? 't' : 'n',
"_align_K_",
align_K ? 't' : 'n');
// Get and set the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_gather_qmm_kernel(
d,
kname,
hash_name,
func_consts,
x,
group_size,
bits,
mode,
bm,
bn,
bk,
wm,
wn,
transpose);
compute_encoder.set_compute_pipeline_state(kernel);
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
int c = 0;
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases_) {
array biases = ensure_row_contiguous(*biases_, d, s);
compute_encoder.set_input_array(biases, c++);
}
compute_encoder.set_input_array(indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(M, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
#endif // MLX_ENABLE_NAX
void gather_qmm_rhs(
const array& x_,
const array& w_,
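gather_qmm_rhs_nax() above uses a two-level name: the base kernel name must match instantiate_gather_qmm_rhs in quantized_nax.metal, and the hash name additionally encodes the align_M/N/K function constants so that differently specialized pipelines get distinct cache entries. A standalone sketch of both strings for one example configuration (illustration only):

// Illustration only: the two-level naming used by gather_qmm_rhs_nax().
#include <iostream>
#include <string>

int main() {
  const int bm = 64, bn = 64, bk = 64, wm = 2, wn = 2;
  const int M = 640, N = 512, K = 500; // example: K is not a multiple of bk
  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Base name for float16, group_size 64, 4 bits, transposed ("nt") layout.
  std::string kname = "affine_gather_qmm_rhs_nax_nt_float16_t_gs_64_b_4";
  kname += "_bm_" + std::to_string(bm) + "_bn_" + std::to_string(bn) +
      "_bk_" + std::to_string(bk) + "_wm_" + std::to_string(wm) + "_wn_" +
      std::to_string(wn);

  // Hash name appends the function-constant specialization.
  std::string hash_name = kname;
  hash_name += std::string("_align_M_") + (align_M ? 't' : 'n');
  hash_name += std::string("_align_N_") + (align_N ? 't' : 'n');
  hash_name += std::string("_align_K_") + (align_K ? 't' : 'n');

  std::cout << kname << "\n" << hash_name << "\n";
  return 0;
}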
@@ -735,6 +985,31 @@ void gather_qmm_rhs(
metal::Device& d,
const Stream& s,
const std::string mode) {
#ifdef MLX_ENABLE_NAX
if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && mode == "affine" && (group_size >= 64)) {
return gather_qmm_rhs_nax(
/* const array& x_ = */ x_,
/* const array& w_ = */ w_,
/* const array& scales_ = */ scales_,
/* const std::optional<array>& biases_ = */ biases_,
/* const array& indices_ = */ indices_,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string mode = */ mode);
}
}
#endif // MLX_ENABLE_NAX
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);
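The broadcast_with_indices lambda in gather_qmm_rhs_nax() above encodes a shape rule worth spelling out: when no lhs_indices are supplied, x is treated as broadcast against rhs_indices, so the copied x gets shape indices.shape() + (x.shape(-2), x.shape(-1)). A small worked example (illustration only; the shapes follow the (L, 1, 1, K) / (L, I) pattern used in the test below):

// Worked example of the broadcast_with_indices shape rule.
#include <cstdio>
#include <vector>

int main() {
  // x: (L, 1, 1, K) and rhs_indices: (L, I), with illustrative values.
  std::vector<int> x_shape = {8, 1, 1, 512};
  std::vector<int> idx_shape = {8, 4};

  long x_size = 1, idx_size = 1;
  for (int d : x_shape) x_size *= d;
  for (int d : idx_shape) idx_size *= d;

  // x.size() / x.shape(-2) / x.shape(-1) = 8 while indices.size() = 32, so the
  // sizes differ and x must be broadcast before the contiguous copy.
  bool needs_broadcast =
      x_size / x_shape[x_shape.size() - 2] / x_shape.back() != idx_size;

  // new_x shape = indices.shape() + (x.shape(-2), x.shape(-1)).
  std::vector<int> new_x_shape = idx_shape;
  new_x_shape.push_back(x_shape[x_shape.size() - 2]);
  new_x_shape.push_back(x_shape.back());

  std::printf("needs_broadcast = %d, new_x shape = (", needs_broadcast);
  for (size_t i = 0; i < new_x_shape.size(); ++i)
    std::printf("%d%s", new_x_shape[i], i + 1 < new_x_shape.size() ? ", " : ")\n");
  return 0;
}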

View File

@@ -164,19 +164,21 @@ void sdpa_full_self_attention_metal(
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
 #ifdef MLX_ENABLE_NAX
-  if (metal::is_nax_available() && q.shape(3) != 80 &&
-      (q.dtype() != float32 || env::enable_tf32())) {
-    return sdpa_full_self_attention_nax(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& q = */ q,
-        /* const array& k = */ k,
-        /* const array& v = */ v,
-        /* const float scale = */ scale,
-        /* array& o = */ o,
-        /* bool do_causal_ = */ do_causal_,
-        /* const std::optional<array>& mask = */ mask,
-        /* const std::optional<array>& sinks = */ sinks);
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && q.shape(3) != 80 &&
+        (q.dtype() != float32 || env::enable_tf32())) {
+      return sdpa_full_self_attention_nax(
+          /* const Stream& s = */ s,
+          /* metal::Device& d = */ d,
+          /* const array& q = */ q,
+          /* const array& k = */ k,
+          /* const array& v = */ v,
+          /* const float scale = */ scale,
+          /* array& o = */ o,
+          /* bool do_causal_ = */ do_causal_,
+          /* const std::optional<array>& mask = */ mask,
+          /* const std::optional<array>& sinks = */ sinks);
+    }
   }
 #endif // MLX_ENABLE_NAX

View File

@@ -834,47 +834,54 @@ class TestQuantized(mlx_tests.MLXTestCase):
             (64, 512, 512, 4, 2, False, "affine"),
         ]
         for L, K, D, E, I, transpose, mode in parameters:
-            if mode == "mxfp4":
-                group_size = 32
-            else:
-                group_size = 64
-            K, D = (K, D) if transpose else (D, K)
-            ishape = (L, I)
-            xshape = (L, 1, 1, K)
-            wshape = (E, D, K) if transpose else (E, K, D)
-            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
-            x = mx.random.normal(xshape) / K**0.5
-            w = mx.random.normal(wshape) / K**0.5
-            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
-            y1 = mx.gather_mm(x, w, rhs_indices=indices)
-            y2 = mx.gather_qmm(
-                x,
-                *wq,
-                group_size=group_size,
-                mode=mode,
-                transpose=transpose,
-                rhs_indices=indices
-            )
-            xs, idx, inv_order = gather_sort(x, indices)
-            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
-            y4 = mx.gather_qmm(
-                xs,
-                *wq,
-                group_size=group_size,
-                mode=mode,
-                rhs_indices=idx,
-                transpose=transpose,
-                sorted_indices=True
-            )
-            y3 = scatter_unsort(y3, inv_order, indices.shape)
-            y4 = scatter_unsort(y4, inv_order, indices.shape)
-            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
-            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
-            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
+            with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
+                if mode == "mxfp4":
+                    group_size = 32
+                else:
+                    group_size = 64
+                K, D = (K, D) if transpose else (D, K)
+                ishape = (L, I)
+                xshape = (L, 1, 1, K)
+                wshape = (E, D, K) if transpose else (E, K, D)
+                indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
+                x = mx.random.normal(xshape) / K**0.5
+                w = mx.random.normal(wshape) / K**0.5
+                w, *wq = quantize(
+                    w, group_size=group_size, mode=mode, transpose=transpose
+                )
+                y1 = mx.gather_mm(x, w, rhs_indices=indices)
+                y2 = mx.gather_qmm(
+                    x,
+                    *wq,
+                    group_size=group_size,
+                    mode=mode,
+                    transpose=transpose,
+                    rhs_indices=indices
+                )
+                xs, idx, inv_order = gather_sort(x, indices)
+                y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
+                y4 = mx.gather_qmm(
+                    xs,
+                    *wq,
+                    group_size=group_size,
+                    mode=mode,
+                    rhs_indices=idx,
+                    transpose=transpose,
+                    sorted_indices=True
+                )
+                y3 = scatter_unsort(y3, inv_order, indices.shape)
+                y4 = scatter_unsort(y4, inv_order, indices.shape)
+                self.assertLess((y1 - y2).abs().max(), 1e-5)
+                self.assertLess((y1 - y3).abs().max(), 1e-5)
+                self.assertLess((y1 - y4).abs().max(), 2e-4)
+                self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
+                self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
+                self.assertTrue(mx.allclose(y1, y4, atol=2e-4))
 
     def test_gather_qmm_grad(self):
         def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
def test_gather_qmm_grad(self): def test_gather_qmm_grad(self):
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort): def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):