matmul jvps (#1772)

This commit is contained in:
Awni Hannun
2025-01-17 10:36:26 -08:00
committed by GitHub
parent f288db8d34
commit 0c259961ac
4 changed files with 138 additions and 21 deletions

View File

@@ -634,6 +634,41 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertEqual(grads[0].dtype, mx.float32)
self.assertEqual(grads[1].dtype, mx.float16)
def test_matmul_jvps(self):
a = mx.random.uniform(shape=(4, 4))
b = mx.random.uniform(shape=(4, 4))
c = mx.random.uniform(shape=(4, 4))
d = mx.random.uniform(shape=(4, 4))
_, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))
self.assertTrue(mx.allclose(tangent[0], c @ b))
_, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))
self.assertTrue(mx.allclose(tangent[0], a @ d))
_, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))
self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))
x = mx.random.uniform(shape=(4, 4))
y = mx.random.uniform(shape=(4, 4))
z = mx.random.uniform(shape=(4, 4))
_, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))
_, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))
self.assertTrue(mx.allclose(tangent, expected))
_, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, z))
_, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))
self.assertTrue(mx.allclose(tangent, expected))
_, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))
_, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))
self.assertTrue(mx.allclose(tangent, expected))
_, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
self.assertTrue(mx.allclose(tangent, expected))
if __name__ == "__main__":
unittest.main()

View File

@@ -34,8 +34,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
[128, 64, 32], # group_size
[2, 4, 8], # bits
[8, 32, 33, 64], # M
[512, 1024], # N
[512, 1024], # K
[128, 256], # N
[128, 256], # K
[True, False], # transposed
)
for group_size, bits, M, N, K, transposed in tests:
@@ -86,6 +86,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(vjp_out[0], expected_out))
def test_qmm_jvp(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
bits = 8
group_size = 64
M = 64
N = 128
K = 128
x = mx.random.normal(shape=(2, M, K), key=k1)
x_tan = mx.ones(shape=(2, M, N))
transposes = [True, False]
for transposed in transposes:
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
def fn(x):
return mx.quantized_matmul(
x, w_q, scales, biases, transposed, group_size, bits
)
_, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))
expected_out = mx.quantized_matmul(
x_tan, w_q, scales, biases, transposed, group_size, bits
)
self.assertTrue(mx.allclose(jvp_out[0], expected_out))
def test_qmm_shapes(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@@ -117,8 +147,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
[512, 1024, 67], # M
[64, 128, 512, 1024], # N
[256, 512, 67], # M
[64, 128], # N
[0, 1, 3, 8], # B
)
for group_size, bits, M, N, B in tests:
@@ -144,8 +174,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product(
[128, 64, 32], # group_size
[2, 3, 4, 6, 8], # bits
[512, 1024], # M
[512, 1024, 67], # N
[128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B
)
for group_size, bits, M, N, B in tests: