mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 23:24:41 +08:00
matmul jvps (#1772)
This commit is contained in:
@@ -34,8 +34,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[8, 32, 33, 64], # M
|
||||
[512, 1024], # N
|
||||
[512, 1024], # K
|
||||
[128, 256], # N
|
||||
[128, 256], # K
|
||||
[True, False], # transposed
|
||||
)
|
||||
for group_size, bits, M, N, K, transposed in tests:
|
||||
@@ -86,6 +86,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(vjp_out[0], expected_out))
|
||||
|
||||
def test_qmm_jvp(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
|
||||
bits = 8
|
||||
group_size = 64
|
||||
M = 64
|
||||
N = 128
|
||||
K = 128
|
||||
|
||||
x = mx.random.normal(shape=(2, M, K), key=k1)
|
||||
x_tan = mx.ones(shape=(2, M, N))
|
||||
|
||||
transposes = [True, False]
|
||||
for transposed in transposes:
|
||||
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
|
||||
def fn(x):
|
||||
return mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transposed, group_size, bits
|
||||
)
|
||||
|
||||
_, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))
|
||||
|
||||
expected_out = mx.quantized_matmul(
|
||||
x_tan, w_q, scales, biases, transposed, group_size, bits
|
||||
)
|
||||
self.assertTrue(mx.allclose(jvp_out[0], expected_out))
|
||||
|
||||
def test_qmm_shapes(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
@@ -117,8 +147,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 3, 4, 6, 8], # bits
|
||||
[512, 1024, 67], # M
|
||||
[64, 128, 512, 1024], # N
|
||||
[256, 512, 67], # M
|
||||
[64, 128], # N
|
||||
[0, 1, 3, 8], # B
|
||||
)
|
||||
for group_size, bits, M, N, B in tests:
|
||||
@@ -144,8 +174,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 3, 4, 6, 8], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024, 67], # N
|
||||
[128, 256], # M
|
||||
[128, 256, 67], # N
|
||||
[0, 1, 3, 8], # B
|
||||
)
|
||||
for group_size, bits, M, N, B in tests:
|
||||
|
Reference in New Issue
Block a user