diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 807215bb4..3346f01c8 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -246,6 +246,36 @@ std::vector<array> AddMM::vjp(
   return vjps;
 }
 
+std::vector<array> AddMM::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  std::vector<array> jvp;
+  for (int i = 0; i < argnums.size(); ++i) {
+    auto arg = argnums[i];
+    if (arg == 0) {
+      if (jvp.empty()) {
+        jvp.push_back(matmul(tangents[i], primals[1], stream()));
+      } else {
+        jvp[0] = addmm(jvp[0], tangents[i], primals[1], 1.0f, 1.0f, stream());
+      }
+    } else if (arg == 1) {
+      if (jvp.empty()) {
+        jvp.push_back(matmul(primals[0], tangents[i], stream()));
+      } else {
+        jvp[0] = addmm(jvp[0], primals[0], tangents[i], 1.0f, 1.0f, stream());
+      }
+    } else {
+      if (jvp.empty()) {
+        jvp.push_back(tangents[i]);
+      } else {
+        jvp[0] = add(jvp[0], tangents[i], stream());
+      }
+    }
+  }
+  return jvp;
+}
+
 bool AddMM::is_equivalent(const Primitive& other) const {
   const AddMM& a_other = static_cast<const AddMM&>(other);
   return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
@@ -2439,6 +2469,26 @@ std::vector<array> Matmul::vjp(
   return vjps;
 }
 
+std::vector<array> Matmul::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  std::vector<array> jvp;
+  for (int i = 0; i < argnums.size(); ++i) {
+    auto arg = argnums[i];
+    if (arg == 0 && i == 0) {
+      jvp.push_back(matmul(tangents[0], primals[1], stream()));
+    } else if (arg == 0 && i == 1) {
+      jvp[0] = addmm(jvp[0], tangents[1], primals[1], 1.0f, 1.0f, stream());
+    } else if (i == 0) {
+      jvp.push_back(matmul(primals[0], tangents[0], stream()));
+    } else if (i == 1) {
+      jvp[0] = addmm(jvp[0], primals[0], tangents[1], 1.0f, 1.0f, stream());
+    }
+  }
+  return jvp;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -2833,7 +2883,7 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
 std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  throw std::runtime_error("QuantizedMatmul::vmap NYI");
+  throw std::runtime_error("[QuantizedMatmul::vmap] NYI");
 }
 
 std::vector<array> QuantizedMatmul::vjp(
@@ -2861,7 +2911,7 @@ std::vector<array> QuantizedMatmul::vjp(
     // gradient wrt to w_q, scales or biases
     else {
       throw std::runtime_error(
-          "QuantizedMatmul::vjp no gradient wrt the quantized matrix yet.");
+          "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
     }
   }
   return vjps;
@@ -2871,7 +2921,19 @@ std::vector<array> QuantizedMatmul::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  throw std::runtime_error("QuantizedMatmul::jvp NYI");
+  if (argnums.size() > 1 || argnums[0] != 0) {
+    throw std::runtime_error(
+        "[QuantizedMatmul::jvp] No JVP wrt the quantized matrix yet.");
+  }
+  return {quantized_matmul(
+      tangents[0],
+      primals[1],
+      primals[2],
+      primals[3],
+      transpose_,
+      group_size_,
+      bits_,
+      stream())};
 }
 
 bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
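The `jvp` overrides above are the product rule written out per differentiated argument: for `C = A B` the pushed-forward tangent is `dA @ B + A @ dB`, with `addmm` used to accumulate when more than one argument carries a tangent. A quick sanity check against that closed form, using only the public `mx.jvp` API (illustrative sketch, not part of this patch):

```python
# Illustrative sketch (not part of the patch): numerically verify
# d(A @ B) = dA @ B + A @ dB, which is what Matmul::jvp implements.
import mlx.core as mx

a, b = mx.random.uniform(shape=(3, 3)), mx.random.uniform(shape=(3, 3))
da, db = mx.random.uniform(shape=(3, 3)), mx.random.uniform(shape=(3, 3))

# Push both tangents through the matmul in one call.
_, (tangent,) = mx.jvp(lambda a, b: a @ b, (a, b), (da, db))
print(mx.allclose(tangent, da @ b + a @ db))  # expected: True
```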
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 6a66d5f5a..89209e0f6 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -193,12 +193,7 @@ class AddMM : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
+  DEFINE_GRADS()
   DEFINE_VMAP()
   DEFINE_PRINT(AddMM)
 
@@ -1459,12 +1454,7 @@ class Matmul : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
+  DEFINE_GRADS()
   DEFINE_VMAP()
   DEFINE_PRINT(Matmul)
   DEFINE_DEFAULT_IS_EQUIVALENT()
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 727d3c060..4ab7fb922 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -634,6 +634,41 @@ class TestAutograd(mlx_tests.MLXTestCase):
         self.assertEqual(grads[0].dtype, mx.float32)
         self.assertEqual(grads[1].dtype, mx.float16)
 
+    def test_matmul_jvps(self):
+        a = mx.random.uniform(shape=(4, 4))
+        b = mx.random.uniform(shape=(4, 4))
+        c = mx.random.uniform(shape=(4, 4))
+        d = mx.random.uniform(shape=(4, 4))
+
+        _, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))
+        self.assertTrue(mx.allclose(tangent[0], c @ b))
+
+        _, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))
+        self.assertTrue(mx.allclose(tangent[0], a @ d))
+
+        _, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))
+        self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))
+
+        x = mx.random.uniform(shape=(4, 4))
+        y = mx.random.uniform(shape=(4, 4))
+        z = mx.random.uniform(shape=(4, 4))
+
+        _, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))
+        _, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+        _, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, z))
+        _, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+        _, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))
+        _, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+        _, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))
+        _, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
+        self.assertTrue(mx.allclose(tangent, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
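The quantized tests that follow rely on the fact that `quantized_matmul` is linear in its activation input, so the new `QuantizedMatmul::jvp` can reuse the op itself: the tangent of the output is the same quantized matmul applied to the tangent of `x`. A hedged sketch of that identity (illustration only, not part of the patch):

```python
# Illustrative sketch (not part of the patch): since quantized_matmul is
# linear in x, its JVP w.r.t. x equals the op applied to the tangent of x.
import mlx.core as mx

group_size, bits = 64, 8
w = mx.random.normal(shape=(128, 128))
w_q, scales, biases = mx.quantize(w, group_size, bits)

x = mx.random.normal(shape=(4, 128))
dx = mx.random.normal(shape=(4, 128))

def fn(x):
    # Positional arguments mirror the call in test_qmm_jvp below.
    return mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)

_, (tangent,) = mx.jvp(fn, (x,), (dx,))
print(mx.allclose(tangent, fn(dx)))  # expected: True
```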
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 6630338fc..363722bcf 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -34,8 +34,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [8, 32, 33, 64],  # M
-            [512, 1024],  # N
-            [512, 1024],  # K
+            [128, 256],  # N
+            [128, 256],  # K
             [True, False],  # transposed
         )
         for group_size, bits, M, N, K, transposed in tests:
@@ -86,6 +86,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
             )
             self.assertTrue(mx.allclose(vjp_out[0], expected_out))
 
+    def test_qmm_jvp(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+
+        bits = 8
+        group_size = 64
+        M = 64
+        N = 128
+        K = 128
+
+        x = mx.random.normal(shape=(2, M, K), key=k1)
+        x_tan = mx.ones(shape=(2, M, N))
+
+        transposes = [True, False]
+        for transposed in transposes:
+            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+            w_q, scales, biases = mx.quantize(w, group_size, bits)
+
+            def fn(x):
+                return mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+
+            _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))
+
+            expected_out = mx.quantized_matmul(
+                x_tan, w_q, scales, biases, transposed, group_size, bits
+            )
+            self.assertTrue(mx.allclose(jvp_out[0], expected_out))
+
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
@@ -117,8 +147,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 3, 4, 6, 8],  # bits
-            [512, 1024, 67],  # M
-            [64, 128, 512, 1024],  # N
+            [256, 512, 67],  # M
+            [64, 128],  # N
             [0, 1, 3, 8],  # B
         )
         for group_size, bits, M, N, B in tests:
@@ -144,8 +174,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 3, 4, 6, 8],  # bits
-            [512, 1024],  # M
-            [512, 1024, 67],  # N
+            [128, 256],  # M
+            [128, 256, 67],  # N
             [0, 1, 3, 8],  # B
         )
         for group_size, bits, M, N, B in tests:
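For completeness, the same kind of closed-form check for the AddMM path with the default `alpha = beta = 1` used throughout the tests above: the tangent of `addmm(c, a, b)` is `dc + da @ b + a @ db`. Sketch only, not part of the patch:

```python
# Illustrative sketch (not part of the patch): with alpha = beta = 1,
# the AddMM JVP reduces to dc + da @ b + a @ db.
import mlx.core as mx

a, b, c = (mx.random.uniform(shape=(4, 4)) for _ in range(3))
da, db, dc = (mx.random.uniform(shape=(4, 4)) for _ in range(3))

_, (tangent,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (da, db, dc))
print(mx.allclose(tangent, dc + da @ b + a @ db))  # expected: True
```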