From c532eb94c11d1257c352b4b3a711a7b5cdeaff1c Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Tue, 18 Nov 2025 12:12:59 -0800 Subject: [PATCH] [WIP] Add init QMMs (some CI tests failing) --- mlx/backend/metal/device.h | 1 - mlx/backend/metal/kernels/CMakeLists.txt | 2 + mlx/backend/metal/kernels/quantized_nax.h | 1559 +++++++++++++++++ mlx/backend/metal/kernels/quantized_nax.metal | 105 ++ .../steel/attn/kernels/steel_attention_nax.h | 2 +- mlx/backend/metal/kernels/steel/attn/nax.h | 3 - mlx/backend/metal/kernels/steel/gemm/nax.h | 3 - mlx/backend/metal/matmul.cpp | 63 +- mlx/backend/metal/quantized.cpp | 275 +++ .../metal/scaled_dot_product_attention.cpp | 28 +- python/tests/test_quantized.py | 81 +- 11 files changed, 2035 insertions(+), 87 deletions(-) create mode 100644 mlx/backend/metal/kernels/quantized_nax.h create mode 100644 mlx/backend/metal/kernels/quantized_nax.metal diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 0b08607be..746fbc088 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -269,7 +269,6 @@ std::unique_ptr> new_scoped_memory_pool(); inline bool is_nax_available() { static bool is_nax_available_ = - /* __builtin_available(macOS 26.2, *) && */ metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17; return is_nax_available_; } diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index d921e557f..514f6038c 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -137,6 +137,8 @@ if(MLX_ENABLE_NAX) build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) + build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS}) + set(STEEL_NAX_ATTN_HEADERS steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h steel/utils/integral_constant.h) diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h new file mode 100644 index 000000000..57d0a7211 --- /dev/null +++ b/mlx/backend/metal/kernels/quantized_nax.h @@ -0,0 +1,1559 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3); +} + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 
32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits 
not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + 
result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 
0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + 
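+  // First resolve which x matrix and which quantized weight set (w, scales,
+  // biases) this threadgroup works on via lhs_indices / rhs_indices, then
+  // apply the per-operand batch strides; y is offset by the flattened batch
+  // position tid.z.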
uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_t_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short lda_tgp = BK_padded; + const short ldb_tgp = BK_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? 
false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + +#pragma clang loop unroll(disable) + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * ldb_tgp + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_n_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using AccumType = float; + + 
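+  // Tiling sketch (as configured above): each simdgroup accumulates an
+  // SM x SN block of the output as a grid of UM x UN NAX subtiles, stepping
+  // over K in BK-sized threadgroup tiles and SK-sized simdgroup tiles. x is
+  // read directly from device memory while the dequantized weights are
+  // staged in the threadgroup buffer Ws.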
using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + +#pragma clang loop unroll(disable) + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + Dtile.store(y + tm * N + tn, N); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 32, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int 
BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid 
[[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void affine_gather_qmm_rhs_nax( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? 
false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + +#pragma clang loop unroll(disable) + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile + Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + +#pragma clang loop unroll(disable) + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile + Btile; + + volatile int compiler_barrier; + + Atile.load_safe(xn + kk1, K, short2((BK - kk1), sgp_sm)); + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if ((offset_next - offset) == BM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, min(int(sgp_sm), max(0, offset - tm))), + short2(BN, min(int(sgp_sm), max(0, offset_next - tm)))); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, max(0, offset - tm)), + short2(max(0, tgp_bn - tn), max(0, offset_next - tm))); + } + }); + }); + } +} \ No newline at end of file diff --git a/mlx/backend/metal/kernels/quantized_nax.metal b/mlx/backend/metal/kernels/quantized_nax.metal new file mode 100644 index 000000000..98703a608 --- /dev/null +++ b/mlx/backend/metal/kernels/quantized_nax.metal @@ -0,0 +1,105 @@ +// Copyright © 2023-2024 Apple Inc. 
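+//
+// Instantiates the NAX quantized matmul kernels defined in quantized_nax.h
+// (affine_qmm_t_nax, affine_qmm_n_nax, affine_gather_qmm_t_nax and
+// affine_gather_qmm_rhs_nax) for float, float16_t and bfloat16_t, group
+// sizes 64 and 128, and 2/3/4/5/6/8-bit weights, all with 64x64x64 tiles
+// and a 2x2 simdgroup layout.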
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/quantized_nax.h" + +#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits, \ + name, \ + type, \ + group_size, \ + bits, bm, bk, bn, wm, wn) + +#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + batched, bm, bk, bn, wm, wn) + +#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \ + name, \ + type, \ + group_size, \ + bits, \ + aligned, bm, bk, bn, wm, wn) + +#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + aligned, \ + batched, bm, bk, bn, wm, wn) + +#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + group_size, \ + bits, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + +#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ + instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \ + instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0) + +#define instantiate_quantized_all_batched(type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits) + + +#define instantiate_quantized_all_single(type, group_size, bits) \ + instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2) + +#define instantiate_quantized_all_aligned(type, group_size, bits) \ + instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \ + instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0) + +#define instantiate_quantized_all_rhs(type, group_size, bits) \ + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \ + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false) + +#define instantiate_quantized_funcs(type, group_size, 
bits) \ + instantiate_quantized_all_batched(type, group_size, bits) \ + instantiate_quantized_all_aligned(type, group_size, bits) \ + instantiate_quantized_all_rhs(type, group_size, bits) + +#define instantiate_quantized_types(group_size, bits) \ + instantiate_quantized_funcs(float, group_size, bits) \ + instantiate_quantized_funcs(float16_t, group_size, bits) \ + instantiate_quantized_funcs(bfloat16_t, group_size, bits) + +#define instantiate_quantized_groups(bits) \ + instantiate_quantized_types(128, bits) \ + instantiate_quantized_types(64, bits) + +#define instantiate_quantized_all() \ + instantiate_quantized_groups(2) \ + instantiate_quantized_groups(3) \ + instantiate_quantized_groups(4) \ + instantiate_quantized_groups(5) \ + instantiate_quantized_groups(6) \ + instantiate_quantized_groups(8) + +instantiate_quantized_all() // clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h index e067ed005..23c2c05f4 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -178,7 +178,7 @@ template < } const bool is_last_bq = int(tid.x) == (params->NQ_aligned); - const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); + // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); const bool is_last_q = is_last_bq; const short lim_rows_q = params->qL_rem - (tm + sm); diff --git a/mlx/backend/metal/kernels/steel/attn/nax.h b/mlx/backend/metal/kernels/steel/attn/nax.h index 27a2dac59..c8f3ea5ef 100644 --- a/mlx/backend/metal/kernels/steel/attn/nax.h +++ b/mlx/backend/metal/kernels/steel/attn/nax.h @@ -39,9 +39,6 @@ struct BaseNAXFrag { kElemRows * kElemCols == kElemsPerFrag, "MMAFrag shape is not consistent with MMAFrag size"); - template - using dtype_mat_t = typename metal::simdgroup_matrix; - template using dtype_frag_t = typename metal::vec; diff --git a/mlx/backend/metal/kernels/steel/gemm/nax.h b/mlx/backend/metal/kernels/steel/gemm/nax.h index bc57a1657..5839176c2 100644 --- a/mlx/backend/metal/kernels/steel/gemm/nax.h +++ b/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -40,9 +40,6 @@ struct BaseNAXFrag { kElemRows * kElemCols == kElemsPerFrag, "MMAFrag shape is not consistent with MMAFrag size"); - template - using dtype_mat_t = typename metal::simdgroup_matrix; - template using dtype_frag_t = typename metal::vec; diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 9f435f73c..94e9e80b9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -359,33 +359,35 @@ void steel_matmul_regular_axpby( float beta /* = 0.0f */) { #ifdef MLX_ENABLE_NAX - if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && - (a.dtype() != float32 || env::enable_tf32())) { - return steel_matmul_regular_axpby_nax( - /* const Stream& s = */ s, - /* metal::Device& d = */ d, - /* const array& a = */ a, - /* const array& b = */ b, - /* const array& c = */ c, - /* array& out = */ out, - /* int M = */ M, - /* int N = */ N, - /* int K = */ K, - /* int batch_size_out = */ batch_size_out, - /* int lda = */ lda, - /* int ldb = */ ldb, - /* int ldd = */ ldd, - /* bool transpose_a = */ transpose_a, - /* bool transpose_b = */ transpose_b, - /* std::vector& copies = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides batch_strides = */ batch_strides, - /* int64_t A_batch_stride = */ A_batch_stride, - /* int64_t 
B_batch_stride = */ B_batch_stride, - /* int64_t matrix_stride_out = */ matrix_stride_out, - /* int64_t C_batch_stride = */ C_batch_stride, - /* float alpha = */ alpha, - /* float beta = */ beta); + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && + (a.dtype() != float32 || env::enable_tf32())) { + return steel_matmul_regular_axpby_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out, + /* int64_t C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha, + /* float beta = */ beta); + } } #endif // MLX_ENABLE_NAX @@ -2196,8 +2198,11 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { if (M == 1 && right_sorted_ == true) { #ifdef MLX_ENABLE_NAX - if (metal::is_nax_available() && a.dtype() != float32) { - return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); + if (__builtin_available( + macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && a.dtype() != float32) { + return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); + } } #endif // MLX_ENABLE_NAX diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index e03e5dca2..15a5ceeb8 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -451,6 +451,96 @@ void qvm( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +#ifdef MLX_ENABLE_NAX + +void qmm_nax( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s, + const std::string& mode) { + int B = out.size() / M / N; + + int wm = 2; + int wn = 2; + int bm = 64; + int bn = 64; + int bk = 64; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + + std::string kname; + kname.reserve(64); + bool aligned = N % 64 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + mode + (transpose ? "_qmm_t_nax_" : "_qmm_n_nax_"), + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + transpose ? (aligned ? "_alN_true" : "_alN_false") : "", + batched ? 
"_batch_1" : "_batch_0"); + std::string template_def; + MTL::ComputePipelineState* kernel; + if (transpose) { + kernel = get_quantized_kernel_wrapped( + d, + kname, + "qmm_t_nax", + mode, + type_string, + group_size, + bits, + aligned, + batched); + } else { + kernel = get_quantized_kernel_wrapped( + d, kname, "qmm_n_nax", mode, type_string, group_size, bits, batched); + } + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void qmm( const array& x, const array& w, @@ -466,6 +556,31 @@ void qmm( metal::Device& d, const Stream& s, const std::string& mode) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && mode == "affine" && (group_size >= 64) && + transpose && (M % 64 == 0) && (N % 64 == 0) && (K % 64 == 0)) { + return qmm_nax( + /* const array& x = */ x, + /* const array& w = */ w, + /* const array& scales = */ scales, + /* const std::optional& biases = */ biases, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string& mode = */ mode); + } + } + +#endif // MLX_ENABLE_NAX + int B = out.size() / M / N; int wm = 2; @@ -719,6 +834,141 @@ void gather_qvm( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +#ifdef MLX_ENABLE_NAX + +void gather_qmm_rhs_nax( + const array& x_, + const array& w_, + const array& scales_, + const std::optional& biases_, + const array& indices_, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s, + const std::string mode) { + // Start by normalizing the indices + array indices = ensure_row_contiguous(indices_, d, s); + + // Broadcast x with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of x broadcasted + // with rhs_indices. We need only broadcast x and copy it as if applying the + // lhs_indices. 
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } + + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + + // Normalize the input arrays + array x = broadcast_with_indices(x_); + array w = ensure_row_contiguous(w_, d, s); + array scales = ensure_row_contiguous(scales_, d, s); + + // TODO: Tune the block sizes + int bm = 64, bn = 64, bk = 64; + int wm = 2, wn = 2; + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + // Make the kernel name + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + mode + + (transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"), + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm_", + bm, + "_bn_", + bn, + "_bk_", + bk, + "_wm_", + wm, + "_wn_", + wn); + + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + kname, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_gather_qmm_kernel( + d, + kname, + hash_name, + func_consts, + x, + group_size, + bits, + mode, + bm, + bn, + bk, + wm, + wn, + transpose); + compute_encoder.set_compute_pipeline_state(kernel); + + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); + + int c = 0; + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases_) { + array biases = ensure_row_contiguous(*biases_, d, s); + compute_encoder.set_input_array(biases, c++); + } + compute_encoder.set_input_array(indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(M, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(K, c++); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void gather_qmm_rhs( const array& x_, const array& w_, @@ -735,6 +985,31 @@ void gather_qmm_rhs( metal::Device& d, const Stream& s, const std::string mode) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && mode == "affine" && (group_size >= 64)) { + return gather_qmm_rhs_nax( + /* const array& x_ = */ x_, + /* const array& w_ = */ w_, + /* const array& scales_ = */ scales_, + /* const std::optional& biases_ = */ biases_, + /* const array& indices_ = */ indices_, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string mode = */ mode); + } + } + +#endif // MLX_ENABLE_NAX + // Start by normalizing the indices array indices = 
   array indices = ensure_row_contiguous(indices_, d, s);

diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 2eac8551f..156093757 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -164,19 +164,21 @@ void sdpa_full_self_attention_metal(
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
 #ifdef MLX_ENABLE_NAX
-  if (metal::is_nax_available() && q.shape(3) != 80 &&
-      (q.dtype() != float32 || env::enable_tf32())) {
-    return sdpa_full_self_attention_nax(
-        /* const Stream& s = */ s,
-        /* metal::Device& d = */ d,
-        /* const array& q = */ q,
-        /* const array& k = */ k,
-        /* const array& v = */ v,
-        /* const float scale = */ scale,
-        /* array& o = */ o,
-        /* bool do_causal_ = */ do_causal_,
-        /* const std::optional<array>& mask = */ mask,
-        /* const std::optional<array>& sinks = */ sinks);
+  if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
+    if (metal::is_nax_available() && q.shape(3) != 80 &&
+        (q.dtype() != float32 || env::enable_tf32())) {
+      return sdpa_full_self_attention_nax(
+          /* const Stream& s = */ s,
+          /* metal::Device& d = */ d,
+          /* const array& q = */ q,
+          /* const array& k = */ k,
+          /* const array& v = */ v,
+          /* const float scale = */ scale,
+          /* array& o = */ o,
+          /* bool do_causal_ = */ do_causal_,
+          /* const std::optional<array>& mask = */ mask,
+          /* const std::optional<array>& sinks = */ sinks);
+    }
   }
 #endif // MLX_ENABLE_NAX

diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index e75106303..5ae1d8104 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -834,47 +834,54 @@ class TestQuantized(mlx_tests.MLXTestCase):
             (64, 512, 512, 4, 2, False, "affine"),
         ]
         for L, K, D, E, I, transpose, mode in parameters:
-            if mode == "mxfp4":
-                group_size = 32
-            else:
-                group_size = 64
-            K, D = (K, D) if transpose else (D, K)
-            ishape = (L, I)
-            xshape = (L, 1, 1, K)
-            wshape = (E, D, K) if transpose else (E, K, D)
+            with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
+                if mode == "mxfp4":
+                    group_size = 32
+                else:
+                    group_size = 64
+                K, D = (K, D) if transpose else (D, K)
+                ishape = (L, I)
+                xshape = (L, 1, 1, K)
+                wshape = (E, D, K) if transpose else (E, K, D)

-            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
-            x = mx.random.normal(xshape) / K**0.5
-            w = mx.random.normal(wshape) / K**0.5
-            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
+                indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
+                x = mx.random.normal(xshape) / K**0.5
+                w = mx.random.normal(wshape) / K**0.5
+                w, *wq = quantize(
+                    w, group_size=group_size, mode=mode, transpose=transpose
+                )

-            y1 = mx.gather_mm(x, w, rhs_indices=indices)
-            y2 = mx.gather_qmm(
-                x,
-                *wq,
-                group_size=group_size,
-                mode=mode,
-                transpose=transpose,
-                rhs_indices=indices
-            )
-            xs, idx, inv_order = gather_sort(x, indices)
-            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
+                y1 = mx.gather_mm(x, w, rhs_indices=indices)
+                y2 = mx.gather_qmm(
+                    x,
+                    *wq,
+                    group_size=group_size,
+                    mode=mode,
+                    transpose=transpose,
+                    rhs_indices=indices
+                )
+                xs, idx, inv_order = gather_sort(x, indices)
+                y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)

-            y4 = mx.gather_qmm(
-                xs,
-                *wq,
-                group_size=group_size,
-                mode=mode,
-                rhs_indices=idx,
-                transpose=transpose,
-                sorted_indices=True
-            )
-            y3 = scatter_unsort(y3, inv_order, indices.shape)
-            y4 = scatter_unsort(y4, inv_order, indices.shape)
+                y4 = mx.gather_qmm(
+                    xs,
+                    *wq,
+                    group_size=group_size,
+                    mode=mode,
+                    rhs_indices=idx,
+                    transpose=transpose,
+                    sorted_indices=True
+                )
+                y3 = scatter_unsort(y3, inv_order, indices.shape)
+                y4 = scatter_unsort(y4, inv_order, indices.shape)

-            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
-            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
-            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
+                self.assertLess((y1 - y2).abs().max(), 1e-5)
+                self.assertLess((y1 - y3).abs().max(), 1e-5)
+                self.assertLess((y1 - y4).abs().max(), 2e-4)
+
+                self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
+                self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
+                self.assertTrue(mx.allclose(y1, y4, atol=2e-4))

     def test_gather_qmm_grad(self):
         def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):