matmul jvps (#1772)
Awni Hannun
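This commit adds forward-mode (JVP) rules for the Matmul, AddMM, and QuantizedMatmul primitives, along with tests, and shrinks some quantized test sizes. All three rules come down to the product rule for matrix products: the tangent of Y = A B is Ẏ = Ȧ B + A Ḃ, so a matmul's tangent is a sum of two matmuls, which the implementations below fold into single addmm calls where possible.

mlx/primitives.cpp: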
@@ -246,6 +246,36 @@ std::vector<array> AddMM::vjp(
   return vjps;
 }
 
+std::vector<array> AddMM::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  std::vector<array> jvp;
+  for (int i = 0; i < argnums.size(); ++i) {
+    auto arg = argnums[i];
+    if (arg == 0) {
+      if (jvp.empty()) {
+        jvp.push_back(matmul(tangents[i], primals[1], stream()));
+      } else {
+        jvp[0] = addmm(jvp[0], tangents[i], primals[1], 1.0f, 1.0f, stream());
+      }
+    } else if (arg == 1) {
+      if (jvp.empty()) {
+        jvp.push_back(matmul(primals[0], tangents[i], stream()));
+      } else {
+        jvp[0] = addmm(jvp[0], primals[0], tangents[i], 1.0f, 1.0f, stream());
+      }
+    } else {
+      if (jvp.empty()) {
+        jvp.push_back(tangents[i]);
+      } else {
+        jvp[0] = add(jvp[0], tangents[i], stream());
+      }
+    }
+  }
+  return jvp;
+}
+
 bool AddMM::is_equivalent(const Primitive& other) const {
   const AddMM& a_other = static_cast<const AddMM&>(other);
   return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
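For reference (not part of the diff): AddMM computes out = α (A B) + β C, so the exact tangent is α (Ȧ B + A Ḃ) + β Ċ. Note that the accumulation above uses fixed 1.0f coefficients, i.e. it effectively assumes α = β = 1, which is the case for the default addmm coefficients and the `a @ b + c` pattern that the tests below exercise.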
@@ -2439,6 +2469,26 @@ std::vector<array> Matmul::vjp(
   return vjps;
 }
 
+std::vector<array> Matmul::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  std::vector<array> jvp;
+  for (int i = 0; i < argnums.size(); ++i) {
+    auto arg = argnums[i];
+    if (arg == 0 && i == 0) {
+      jvp.push_back(matmul(tangents[0], primals[1], stream()));
+    } else if (arg == 0 && i == 1) {
+      jvp[0] = addmm(jvp[0], tangents[1], primals[1], 1.0f, 1.0f, stream());
+    } else if (i == 0) {
+      jvp.push_back(matmul(primals[0], tangents[0], stream()));
+    } else if (i == 1) {
+      jvp[0] = addmm(jvp[0], primals[0], tangents[1], 1.0f, 1.0f, stream());
+    }
+  }
+  return jvp;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
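As a quick sanity check of the product rule outside the test suite (a sketch, not part of this commit; it assumes only the public mx.jvp API used in the tests below):

import mlx.core as mx

a = mx.random.uniform(shape=(4, 4))
b = mx.random.uniform(shape=(4, 4))
da = mx.random.uniform(shape=(4, 4))
db = mx.random.uniform(shape=(4, 4))

# Forward-mode derivative of (a, b) -> a @ b in the direction (da, db)
_, (tangent,) = mx.jvp(lambda a, b: a @ b, (a, b), (da, db))

# The product rule and a central finite difference should both agree
fd = ((a + 1e-3 * da) @ (b + 1e-3 * db) - (a - 1e-3 * da) @ (b - 1e-3 * db)) / 2e-3
assert mx.allclose(tangent, da @ b + a @ db, atol=1e-5)
assert mx.allclose(tangent, fd, atol=1e-3)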
@@ -2833,7 +2883,7 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
 std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  throw std::runtime_error("QuantizedMatmul::vmap NYI");
+  throw std::runtime_error("[QuantizedMatmul::vmap] NYI");
 }
 
 std::vector<array> QuantizedMatmul::vjp(
@@ -2861,7 +2911,7 @@ std::vector<array> QuantizedMatmul::vjp(
     // gradient wrt to w_q, scales or biases
     else {
       throw std::runtime_error(
-          "QuantizedMatmul::vjp no gradient wrt the quantized matrix yet.");
+          "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
     }
   }
   return vjps;
@@ -2871,7 +2921,19 @@ std::vector<array> QuantizedMatmul::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  throw std::runtime_error("QuantizedMatmul::jvp NYI");
+  if (argnums.size() > 1 || argnums[0] != 0) {
+    throw std::runtime_error(
+        "[QuantizedMatmul::jvp] No JVP wrt the quantized matrix yet.");
+  }
+  return {quantized_matmul(
+      tangents[0],
+      primals[1],
+      primals[2],
+      primals[3],
+      transpose_,
+      group_size_,
+      bits_,
+      stream())};
 }
 
 bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
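Because the quantized weight, scales, and biases are treated as constants (tangents with respect to them are rejected above), quantized_matmul is linear in x, and the JVP of a linear map is the map itself applied to the tangent, which is exactly what the return statement computes.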
mlx/primitives.h:
@@ -193,12 +193,7 @@ class AddMM : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
+  DEFINE_GRADS()
   DEFINE_VMAP()
   DEFINE_PRINT(AddMM)
 
@@ -1459,12 +1454,7 @@ class Matmul : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
+  DEFINE_GRADS()
   DEFINE_VMAP()
   DEFINE_PRINT(Matmul)
   DEFINE_DEFAULT_IS_EQUIVALENT()
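In both class declarations the hand-written vjp declaration is replaced by DEFINE_GRADS(), the macro these headers use to declare the vjp and jvp overrides together, so each primitive now picks up the new jvp defined above.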
python/tests/test_autograd.py:
@@ -634,6 +634,41 @@ class TestAutograd(mlx_tests.MLXTestCase):
         self.assertEqual(grads[0].dtype, mx.float32)
         self.assertEqual(grads[1].dtype, mx.float16)
 
+    def test_matmul_jvps(self):
+        a = mx.random.uniform(shape=(4, 4))
+        b = mx.random.uniform(shape=(4, 4))
+        c = mx.random.uniform(shape=(4, 4))
+        d = mx.random.uniform(shape=(4, 4))
+
+        _, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))
+        self.assertTrue(mx.allclose(tangent[0], c @ b))
+
+        _, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))
+        self.assertTrue(mx.allclose(tangent[0], a @ d))
+
+        _, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))
+        self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))
+
+        x = mx.random.uniform(shape=(4, 4))
+        y = mx.random.uniform(shape=(4, 4))
+        z = mx.random.uniform(shape=(4, 4))
+
+        _, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))
+        _, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+        _, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, z))
+        _, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+        _, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))
+        _, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+        _, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))
+        _, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
+        self.assertTrue(mx.allclose(tangent, expected))
+
+
 if __name__ == "__main__":
     unittest.main()
python/tests/test_quantized.py:
@@ -34,8 +34,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [8, 32, 33, 64],  # M
-            [512, 1024],  # N
-            [512, 1024],  # K
+            [128, 256],  # N
+            [128, 256],  # K
             [True, False],  # transposed
         )
         for group_size, bits, M, N, K, transposed in tests:
@@ -86,6 +86,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
             )
             self.assertTrue(mx.allclose(vjp_out[0], expected_out))
 
+    def test_qmm_jvp(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+
+        bits = 8
+        group_size = 64
+        M = 64
+        N = 128
+        K = 128
+
+        x = mx.random.normal(shape=(2, M, K), key=k1)
+        x_tan = mx.ones(shape=(2, M, N))
+
+        transposes = [True, False]
+        for transposed in transposes:
+            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+            w_q, scales, biases = mx.quantize(w, group_size, bits)
+
+            def fn(x):
+                return mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+
+            _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))
+
+            expected_out = mx.quantized_matmul(
+                x_tan, w_q, scales, biases, transposed, group_size, bits
+            )
+            self.assertTrue(mx.allclose(jvp_out[0], expected_out))
+
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
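One subtlety in test_qmm_jvp above: x has shape (2, M, K) but its tangent x_tan is created with shape (2, M, N); a tangent must match its primal's shape, and the two only coincide here because K == N == 128 in this configuration.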
@@ -117,8 +147,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 3, 4, 6, 8],  # bits
-            [512, 1024, 67],  # M
-            [64, 128, 512, 1024],  # N
+            [256, 512, 67],  # M
+            [64, 128],  # N
             [0, 1, 3, 8],  # B
         )
         for group_size, bits, M, N, B in tests:
@@ -144,8 +174,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         tests = product(
             [128, 64, 32],  # group_size
             [2, 3, 4, 6, 8],  # bits
-            [512, 1024],  # M
-            [512, 1024, 67],  # N
+            [128, 256],  # M
+            [128, 256, 67],  # N
             [0, 1, 3, 8],  # B
         )
         for group_size, bits, M, N, B in tests: