mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Neural Accelerator Support (#2772)
This commit is contained in:
@@ -163,6 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_qmm(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
@@ -178,8 +179,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
bits=bits,
|
||||
transposed=transposed,
|
||||
):
|
||||
x = mx.random.normal(shape=(M, K), key=k1)
|
||||
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
|
||||
x = mx.random.normal(shape=(M, K), key=k1) / K**0.5
|
||||
w = (
|
||||
mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
|
||||
/ K**0.5
|
||||
)
|
||||
x = x.astype(dtype)
|
||||
w = w.astype(dtype)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(
|
||||
@@ -187,7 +193,9 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
)
|
||||
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
|
||||
self.assertLess((y_q - y_hat).abs().max(), tol)
|
||||
|
||||
def test_qmm_vjp(self):
|
||||
key = mx.random.key(0)
|
||||
@@ -833,48 +841,75 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
(133, 512, 555, 4, 2, False, "affine"),
|
||||
(64, 512, 512, 4, 2, False, "affine"),
|
||||
]
|
||||
|
||||
key = mx.random.key(0)
|
||||
k1, k2, k3 = mx.random.split(key, 3)
|
||||
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
|
||||
|
||||
for L, K, D, E, I, transpose, mode in parameters:
|
||||
if mode == "mxfp4":
|
||||
group_size = 32
|
||||
else:
|
||||
group_size = 64
|
||||
K, D = (K, D) if transpose else (D, K)
|
||||
ishape = (L, I)
|
||||
xshape = (L, 1, 1, K)
|
||||
wshape = (E, D, K) if transpose else (E, K, D)
|
||||
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
|
||||
if mode == "mxfp4":
|
||||
group_size = 32
|
||||
dtype = (
|
||||
mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
|
||||
)
|
||||
else:
|
||||
group_size = 64
|
||||
dtype = (
|
||||
mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
|
||||
)
|
||||
|
||||
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
|
||||
x = mx.random.normal(xshape) / K**0.5
|
||||
w = mx.random.normal(wshape) / K**0.5
|
||||
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
|
||||
K, D = (K, D) if transpose else (D, K)
|
||||
ishape = (L, I)
|
||||
xshape = (L, 1, 1, K)
|
||||
wshape = (E, D, K) if transpose else (E, K, D)
|
||||
|
||||
y1 = mx.gather_mm(x, w, rhs_indices=indices)
|
||||
y2 = mx.gather_qmm(
|
||||
x,
|
||||
*wq,
|
||||
group_size=group_size,
|
||||
mode=mode,
|
||||
transpose=transpose,
|
||||
rhs_indices=indices
|
||||
)
|
||||
xs, idx, inv_order = gather_sort(x, indices)
|
||||
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
|
||||
indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(
|
||||
mx.uint32
|
||||
)
|
||||
x = mx.random.normal(xshape, key=k2) / K**0.5
|
||||
w = mx.random.normal(wshape, key=k3) / K**0.5
|
||||
|
||||
y4 = mx.gather_qmm(
|
||||
xs,
|
||||
*wq,
|
||||
group_size=group_size,
|
||||
mode=mode,
|
||||
rhs_indices=idx,
|
||||
transpose=transpose,
|
||||
sorted_indices=True
|
||||
)
|
||||
y3 = scatter_unsort(y3, inv_order, indices.shape)
|
||||
y4 = scatter_unsort(y4, inv_order, indices.shape)
|
||||
x = x.astype(dtype)
|
||||
w = w.astype(dtype)
|
||||
|
||||
self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||
w, *wq = quantize(
|
||||
w, group_size=group_size, mode=mode, transpose=transpose
|
||||
)
|
||||
|
||||
y1 = mx.gather_mm(x, w, rhs_indices=indices)
|
||||
y2 = mx.gather_qmm(
|
||||
x,
|
||||
*wq,
|
||||
group_size=group_size,
|
||||
mode=mode,
|
||||
transpose=transpose,
|
||||
rhs_indices=indices
|
||||
)
|
||||
xs, idx, inv_order = gather_sort(x, indices)
|
||||
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
|
||||
|
||||
y4 = mx.gather_qmm(
|
||||
xs,
|
||||
*wq,
|
||||
group_size=group_size,
|
||||
mode=mode,
|
||||
rhs_indices=idx,
|
||||
transpose=transpose,
|
||||
sorted_indices=True
|
||||
)
|
||||
y3 = scatter_unsort(y3, inv_order, indices.shape)
|
||||
y4 = scatter_unsort(y4, inv_order, indices.shape)
|
||||
|
||||
tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4
|
||||
|
||||
self.assertLess((y1 - y2).abs().max(), tol)
|
||||
self.assertLess((y1 - y3).abs().max(), tol)
|
||||
self.assertLess((y1 - y4).abs().max(), tol)
|
||||
|
||||
self.assertTrue(mx.allclose(y1, y2, atol=tol))
|
||||
self.assertTrue(mx.allclose(y1, y3, atol=tol))
|
||||
self.assertTrue(mx.allclose(y1, y4, atol=tol))
|
||||
|
||||
def test_gather_qmm_grad(self):
|
||||
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
|
||||
@@ -898,10 +933,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
sorted_indices=sort,
|
||||
)
|
||||
|
||||
x = mx.random.normal((16, 1, 256))
|
||||
w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
|
||||
indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
|
||||
cotan = mx.random.normal((16, 1, 256))
|
||||
key = mx.random.key(0)
|
||||
k1, k2, k3, k4 = mx.random.split(key, 4)
|
||||
dtype = mx.float32
|
||||
|
||||
x = mx.random.normal((16, 1, 256), key=k1).astype(dtype)
|
||||
w, s, b = mx.quantize(mx.random.normal((4, 256, 256), key=k2).astype(dtype))
|
||||
indices = mx.sort(mx.random.randint(0, 4, shape=(16,), key=k3))
|
||||
cotan = mx.random.normal((16, 1, 256), key=k4).astype(dtype)
|
||||
|
||||
(o1,), (dx1, ds1, db1) = mx.vjp(
|
||||
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
|
||||
@@ -914,6 +953,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
[cotan],
|
||||
)
|
||||
|
||||
self.assertLess((o1 - o2).abs().max(), 1e-4)
|
||||
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
|
||||
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
|
||||
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user