Add fixes to int QMMs (CI passing)
@@ -860,9 +860,6 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
const short tm = SM * (simd_gid / WN);
const short tn = SN * (simd_gid % WN);

const short lda_tgp = BK_padded;
const short ldb_tgp = BK_padded;

constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
@@ -898,7 +895,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(

threadgroup_barrier(mem_flags::mem_threadgroup);

#pragma clang loop unroll(disable)
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, TN, TK, BSubTile> Btile;
@@ -911,7 +908,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));
}

Btile.template load<T, BK_padded, 1>(Ws + tn * ldb_tgp + kk1);
Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);

tile_matmad_nax(
Dtile,
@@ -964,6 +961,8 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;

static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@@ -997,8 +996,8 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
y += y_row * static_cast<int64_t>(N) + y_col;

// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
// const short num_els = min(BM, M - y_row);
// const short num_outs = min(BN, N - y_col);
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);

constexpr short UM = 16;
@@ -1037,7 +1036,7 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);

#pragma clang loop unroll(disable)
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, TK, TN, BSubTile> Btile;
@@ -1408,10 +1407,17 @@ template <
const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);

const short sgp_sm = align_M ? SM : min(SM, short(M - (y_row + tm)));
const short sgp_sm =
align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));
const short sgp_sn =
align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));

const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);

constexpr short BR = transpose ? TN : TK;
constexpr short BC = transpose ? TK : TN;

using AccumType = float;

using ASubTile = NAXSubTile<T, UM, UK>;
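The clamp above matters at the matrix edges: without max(0, ...), a simdgroup whose tile origin lies past M or N would compute a negative extent. A minimal sketch of the arithmetic, in plain Python with illustrative tile sizes (not the kernel's actual template parameters):

def simdgroup_extent(dim, tile_origin, tile_size, aligned):
    # Rows/cols of a tile_size-wide tile that actually fall inside a matrix of size dim.
    if aligned:
        return tile_size
    return min(tile_size, max(0, dim - tile_origin))

# M = 70 with a 16-row simdgroup tile: the tile at row 64 owns 6 valid rows,
# and the tile at row 80 owns 0 rather than the negative 70 - 80.
assert simdgroup_extent(70, 64, 16, aligned=False) == 6
assert simdgroup_extent(70, 80, 16, aligned=False) == 0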
@@ -1467,11 +1473,10 @@ template <

threadgroup_barrier(mem_flags::mem_threadgroup);

#pragma clang loop unroll(disable)
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, transpose ? TN : TK, transpose ? TK : TN, BSubTile>
Btile;
NAXTile<T, BR, BC, BSubTile> Btile;

volatile int compiler_barrier;
@@ -1506,15 +1511,15 @@ template <
loader_w.load_safe(tile_w);
threadgroup_barrier(mem_flags::mem_threadgroup);

#pragma clang loop unroll(disable)
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, TM, TK, ASubTile> Atile;
NAXTile<T, transpose ? TN : TK, transpose ? TK : TN, BSubTile>
Btile;
NAXTile<T, BR, BC, BSubTile> Btile;

volatile int compiler_barrier;

Atile.load_safe(xn + kk1, K, short2((BK - kk1), sgp_sm));
const short psk = min(int(SK), max(0, (BK - kk1)));
Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));

if constexpr (transpose) {
Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);
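The psk clamp above bounds the slab width handed to load_safe to at most SK and at least 0, instead of passing the raw remainder BK - kk1. A small sketch of the loop arithmetic in plain Python, with illustrative (not actual) tile sizes:

BK, SK = 32, 24  # illustrative sizes only
for kk1 in range(0, BK, SK):
    raw = BK - kk1               # what the old call passed as the slab width
    psk = min(SK, max(0, raw))   # what the new call passes
    print(f"kk1={kk1:2d}: remainder {raw:2d} -> clamped width {psk:2d}")
# kk1= 0: remainder 32 -> clamped width 24
# kk1=24: remainder  8 -> clamped width  8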
@@ -1535,23 +1540,23 @@ template <

threadgroup_barrier(mem_flags::mem_threadgroup);

const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));
const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));

// Store results to device memory
if constexpr (kAlignedN.value) {
if ((offset_next - offset) == BM) {
if (m_lo_lim == 0 && m_hi_lim == SM) {
Dtile.store(y + tm * N + tn, N);
} else {
Dtile.store_slice(
y + tm * N + tn,
N,
short2(0, min(int(sgp_sm), max(0, offset - tm))),
short2(BN, min(int(sgp_sm), max(0, offset_next - tm))));
y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));
}
} else {
Dtile.store_slice(
y + tm * N + tn,
N,
short2(0, max(0, offset - tm)),
short2(max(0, tgp_bn - tn), max(0, offset_next - tm)));
short2(0, m_lo_lim),
short2(sgp_sn, m_hi_lim));
}
});
});
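The two limits above define, per simdgroup, the half-open row window [m_lo_lim, m_hi_lim) that intersects both the segment [offset, offset_next) and the tile's own valid extent; the same window now drives the aligned fast path and both slice stores. A rough model in plain Python (variable names mirror the kernel, values are illustrative):

def row_window(offset, offset_next, tm, sgp_sm):
    m_lo = min(sgp_sm, max(0, offset - tm))
    m_hi = min(sgp_sm, max(0, offset_next - tm))
    return m_lo, m_hi

# A 16-row tile starting at tm=16 with a segment covering block rows 20..40
# stores only its rows 4..16; the full store is taken when the window is (0, SM).
assert row_window(offset=20, offset_next=40, tm=16, sgp_sm=16) == (4, 16)
assert row_window(offset=0, offset_next=64, tm=16, sgp_sm=16) == (0, 16)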
@@ -1825,7 +1825,7 @@ void gather_mm_rhs_nax(
base_name.reserve(64);
concatenate(
base_name,
"steel_gather_mm_rhs_mxu_n",
"steel_gather_mm_rhs_nax_n",
transpose_b ? 't' : 'n',
'_',
type_to_name(a),
@@ -2200,7 +2200,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && a.dtype() != float32) {
if (metal::is_nax_available() &&
(a.dtype() != float32 || env::enable_tf32())) {
return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
}
}
@@ -539,6 +539,120 @@ void qmm_nax(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void gather_qmm_nax(
const array& x,
const array& w,
const array& scales,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
bool transpose,
int group_size,
int bits,
int M,
int N,
int K,
metal::Device& d,
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;

int wm = 2;
int wn = 2;
int bm = 64;
int bn = 64;
int bk = 32;
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);

std::string kname;
kname.reserve(64);
bool aligned = N % 64 == 0;
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
mode + (transpose ? "_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
MTL::ComputePipelineState* kernel;
if (transpose) {
kernel = get_quantized_kernel_wrapped(
d,
kname,
"gather_qmm_t_nax_",
mode,
type_string,
group_size,
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
aligned);
} else {
kernel = get_quantized_kernel_wrapped(
d,
kname,
"gather_qmm_n_nax_",
mode,
type_string,
group_size,
bits,
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
}

auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);

int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

#endif // MLX_ENABLE_NAX

void qmm(
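As a rough picture of what the new gather_qmm_nax sets up on the host: it assembles a kernel name that encodes the specialization and launches one threadgroup per 64x64 output tile per batch entry. The sketch below (plain Python) reuses the name fragments from the code; the concrete dtype, bits, group size, and problem sizes are made up for illustration:

mode, transpose, type_string = "affine", True, "float16"   # illustrative values
group_size, bits = 64, 4
wm, wn, bm, bn, bk = 2, 2, 64, 64, 32
M, N, B = 200, 512, 8
aligned = N % 64 == 0

kname = (
    mode
    + ("_gather_qmm_t_nax_" if transpose else "_gather_qmm_n_nax_")
    + f"{type_string}_gs_{group_size}_b_{bits}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}"
    + (("_alN_true" if aligned else "_alN_false") if transpose else "")
)
grid_dims = ((N + bn - 1) // bn, (M + bm - 1) // bm, B)  # one threadgroup per output tile
group_dims = (32, wn, wm)                                # 32 lanes x wn x wm simdgroups

print(kname)       # affine_gather_qmm_t_nax_float16_gs_64_b_4_bm64_bn64_bk32_wm2_wn2_alN_true
print(grid_dims)   # (8, 4, 8)
print(group_dims)  # (32, 2, 2)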
@@ -559,8 +673,9 @@ void qmm(
#ifdef MLX_ENABLE_NAX

if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && mode == "affine" && (group_size >= 64) &&
transpose && (M % 64 == 0) && (N % 64 == 0) && (K % 64 == 0)) {
if (metal::is_nax_available() && transpose &&
(x.dtype() != float32 || env::enable_tf32()) && mode == "affine" &&
(group_size >= 64) && (K % 64 == 0)) {
return qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
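The rewritten condition drops the M % 64 and N % 64 requirements and adds a dtype gate, so the NAX path now also covers unaligned M and N but skips float32 unless TF32 is enabled. A plain-Python sketch of the predicate (nax_available and tf32_enabled stand in for metal::is_nax_available() and env::enable_tf32()):

def use_qmm_nax(nax_available, tf32_enabled, dtype, transpose, mode, group_size, K):
    return (
        nax_available
        and transpose
        and (dtype != "float32" or tf32_enabled)
        and mode == "affine"
        and group_size >= 64
        and K % 64 == 0
    )

assert use_qmm_nax(True, False, "float16", True, "affine", 64, 4096)      # unaligned M/N is fine now
assert not use_qmm_nax(True, False, "float32", True, "affine", 64, 4096)  # float32 needs TF32
assert not use_qmm_nax(True, False, "float16", True, "affine", 64, 1000)  # K must stay a multiple of 64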
@@ -658,6 +773,34 @@ void gather_qmm(
metal::Device& d,
const Stream& s,
const std::string& mode) {
#ifdef MLX_ENABLE_NAX

if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && transpose &&
(x.dtype() != float32 || env::enable_tf32()) && transpose &&
mode == "affine" && (group_size >= 64) && (K % 64 == 0)) {
return gather_qmm_nax(
/* const array& x = */ x,
/* const array& w = */ w,
/* const array& scales = */ scales,
/* const std::optional<array>& biases = */ biases,
/* const array& lhs_indices = */ lhs_indices,
/* const array& rhs_indices = */ rhs_indices,
/* array& out = */ out,
/* bool transpose = */ transpose,
/* int group_size = */ group_size,
/* int bits = */ bits,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* metal::Device& d = */ d,
/* const Stream& s = */ s,
/* const std::string& mode = */ mode);
}
}

#endif // MLX_ENABLE_NAX

int B = out.size() / M / N;

int wm = 2;
@@ -988,7 +1131,9 @@ void gather_qmm_rhs(
#ifdef MLX_ENABLE_NAX

if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
if (metal::is_nax_available() && mode == "affine" && (group_size >= 64)) {
if (metal::is_nax_available() &&
(x_.dtype() != float32 || env::enable_tf32()) && mode == "affine" &&
(group_size >= 64)) {
return gather_qmm_rhs_nax(
/* const array& x_ = */ x_,
/* const array& w_ = */ w_,
@@ -163,6 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
dtype = mx.float16
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
@@ -178,8 +179,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
bits=bits,
transposed=transposed,
):
x = mx.random.normal(shape=(M, K), key=k1)
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
x = mx.random.normal(shape=(M, K), key=k1) / K**0.5
w = (
mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
/ K**0.5
)
x = x.astype(dtype)
w = w.astype(dtype)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
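Scaling x and w by K**-0.5 keeps the matmul outputs and the accumulated quantization error from growing with K, so the fixed tolerances used later in the test stay meaningful across the swept sizes. An illustrative check of the effect with numpy (not part of the test itself):

import numpy as np

rng = np.random.default_rng(0)
for K in (64, 512, 4096):
    x = rng.standard_normal((128, K))
    w = rng.standard_normal((K, 128))
    raw = np.abs(x @ w).max()                            # grows roughly like sqrt(K)
    scaled = np.abs((x / K**0.5) @ (w / K**0.5)).max()   # stays small for every K
    print(f"K={K:4d}  max|x @ w| ~ {raw:7.1f}   scaled ~ {scaled:.3f}")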
@@ -833,20 +839,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]

key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3)
dtype = mx.float16

for L, K, D, E, I, transpose, mode in parameters:
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
if mode == "mxfp4":
group_size = 32
dtype = mx.bfloat16
else:
group_size = 64
dtype = mx.float16

K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
wshape = (E, D, K) if transpose else (E, K, D)

indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(
mx.uint32
)
x = mx.random.normal(xshape, key=k2) / K**0.5
w = mx.random.normal(wshape, key=k3) / K**0.5

x = x.astype(dtype)
w = w.astype(dtype)

w, *wq = quantize(
w, group_size=group_size, mode=mode, transpose=transpose
)
@@ -875,13 +895,15 @@ class TestQuantized(mlx_tests.MLXTestCase):
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)

self.assertLess((y1 - y2).abs().max(), 1e-5)
self.assertLess((y1 - y3).abs().max(), 1e-5)
self.assertLess((y1 - y4).abs().max(), 2e-4)
tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4

self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=2e-4))
self.assertLess((y1 - y2).abs().max(), tol)
self.assertLess((y1 - y3).abs().max(), tol)
self.assertLess((y1 - y4).abs().max(), tol)

self.assertTrue(mx.allclose(y1, y2, atol=tol))
self.assertTrue(mx.allclose(y1, y3, atol=tol))
self.assertTrue(mx.allclose(y1, y4, atol=tol))
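The hard-coded thresholds are replaced by a single tolerance keyed off the working dtype, since the mxfp4 cases above run in bfloat16 and need the looser half-precision bound. The pattern, as a minimal sketch (dtype strings stand in for the mx dtypes):

def tolerance(dtype):
    return 1.5e-5 if dtype == "float32" else 2.5e-4

assert tolerance("float32") < tolerance("bfloat16") == tolerance("float16")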
def test_gather_qmm_grad(self):
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
@@ -905,10 +927,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
sorted_indices=sort,
)

x = mx.random.normal((16, 1, 256))
w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
cotan = mx.random.normal((16, 1, 256))
key = mx.random.key(0)
k1, k2, k3, k4 = mx.random.split(key, 4)
dtype = mx.float32

x = mx.random.normal((16, 1, 256), key=k1).astype(dtype)
w, s, b = mx.quantize(mx.random.normal((4, 256, 256), key=k2).astype(dtype))
indices = mx.sort(mx.random.randint(0, 4, shape=(16,), key=k3))
cotan = mx.random.normal((16, 1, 256), key=k4).astype(dtype)

(o1,), (dx1, ds1, db1) = mx.vjp(
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
@@ -921,6 +947,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
[cotan],
)

self.assertLess((o1 - o2).abs().max(), 1e-4)
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))