	stable cumprod grad at 0 (#1167)
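
Background for the change: the previous VJP for cumulative products accumulated cumprod(y) * cotangent with a cumsum in the opposite direction and then divided by the input, which at an element that is exactly zero divides 0 by 0 and yields nan (in the default inclusive, forward case). A minimal NumPy sketch of that removed formula, with an illustrative helper name that is not MLX API, shows the failure this commit fixes:

import numpy as np

# Naive cumprod VJP, mirroring the formula removed below: reverse-cumsum the
# partial gradients cumprod(y) * cotangent, then divide by the input.
def naive_cumprod_vjp(y, cotan):
    c = np.cumprod(y)
    return np.cumsum((c * cotan)[::-1])[::-1] / y

y = np.array([2.0, 0.0, 2.0, 2.0, 3.0])
with np.errstate(invalid="ignore"):
    print(naive_cumprod_vjp(y, np.ones_like(y)))  # [ 1. nan  0.  0.  0.]

The nan sits at the zero element, where the correct gradient is 38 (see the new test below).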
@@ -2748,12 +2748,52 @@ std::vector<array> Scan::vjp(
   if (reduce_type_ == Scan::Sum) {
     return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
   } else if (reduce_type_ == Scan::Prod) {
-    // TODO: Make it numerically stable when we introduce where()
-    auto prod = outputs[0];
-    auto partial_grads = multiply(prod, cotangents[0], stream());
-    auto accum_grads =
-        cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
-    return {divide(accum_grads, primals[0], stream())};
+    auto in = primals[0];
+    // Find the location of the first 0 and set it to 1:
+    // - A: Exclusive cumprod
+    // - B: Inclusive cumprod
+    // - Find the location that is 0 in B and not zero in A
+    // Compute the gradient by:
+    // - Compute the regular gradient for everything before the first zero
+    // - Set the first zero to 1 and redo the computation, use this for the
+    //   gradient of the first zero
+    // - Everything after the first zero has a gradient of 0
+
+    // Get inclusive and exclusive cum prods
+    auto cprod_exclusive = cumprod(in, axis_, reverse_, !inclusive_, stream());
+    auto cprod_inclusive = outputs[0];
+    if (!inclusive_) {
+      std::swap(cprod_exclusive, cprod_inclusive);
+    }
+
+    // Make the mask for the first zero
+    auto z = array(0, in.dtype());
+    auto eq_zero = equal(cprod_inclusive, z, stream());
+    auto first_zero =
+        logical_and(eq_zero, not_equal(cprod_exclusive, z, stream()), stream());
+
+    auto to_partial_grad = [this, &cotangents](const array& arr) {
+      return cumsum(
+          multiply(arr, cotangents[0], stream()),
+          axis_,
+          !reverse_,
+          inclusive_,
+          stream());
+    };
+
+    auto cprod_with_one = cumprod(
+        where(first_zero, array(1, in.dtype()), in, stream()),
+        axis_,
+        reverse_,
+        inclusive_,
+        stream());
+    auto grad_with_one = to_partial_grad(cprod_with_one);
+    auto grad = divide(to_partial_grad(outputs[0]), in, stream());
+    return {where(
+        first_zero,
+        grad_with_one,
+        where(eq_zero, z, grad, stream()),
+        stream())};
   } else {
     // Can probably be implemented by equals and then cummax to make the mask
     throw std::runtime_error("VJP is not implemented for cumulative min/max");
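
Read as a recipe, the new branch does three things: compute the regular gradient cumsum(cumprod * cotangent) / input, recompute the cumprod with the first zero replaced by 1 to obtain that element's gradient, and force the gradient to 0 everywhere past the first zero. A NumPy sketch of the same recipe for the default forward, inclusive case (function and variable names here are illustrative, not MLX API):

import numpy as np

def stable_cumprod_vjp(y, cotan):
    incl = np.cumprod(y)                       # inclusive cumprod (the primal output)
    excl = np.concatenate(([1.0], incl[:-1]))  # exclusive cumprod
    eq_zero = incl == 0                        # at or after the first zero
    first_zero = eq_zero & (excl != 0)         # exactly the first zero

    # Shared piece of the VJP: reverse cumulative sum of (values * cotangent).
    def partial_grad(vals):
        return np.cumsum((vals * cotan)[::-1])[::-1]

    # Regular gradient; the divisor is patched to 1 at zeros only to avoid a
    # warning, since those entries are masked out below anyway.
    grad = partial_grad(incl) / np.where(y == 0, 1.0, y)
    # Gradient of the first zero: redo the cumprod with that zero set to 1.
    grad_with_one = partial_grad(np.cumprod(np.where(first_zero, 1.0, y)))
    # Everything past the first zero gets gradient 0.
    return np.where(first_zero, grad_with_one, np.where(eq_zero, 0.0, grad))

y = np.array([2.0, 0.0, 2.0, 2.0, 3.0])
print(stable_cumprod_vjp(y, np.ones_like(y)))  # [ 1. 38.  0.  0.  0.]

With no zeros present, first_zero and eq_zero are all false and the result reduces to the regular gradient, e.g. [20, 38, 18, 16, 8] for y = [2, 1, 2, 2, 3], matching the first case of the new test below.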
@@ -423,6 +423,79 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
         self.assertEqual(grad.item(), 1.0)
 
+    def test_cumprod_grad(self):
+        def fun(y):
+            return mx.cumprod(y).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+
 if __name__ == "__main__":
     unittest.main()
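
The expected vectors in the test are easy to cross-check numerically. A throwaway central-difference helper (the name is ours, not part of the MLX test suite) agrees with mx.grad on the zero-containing inputs:

import mlx.core as mx

def numerical_grad(f, y, eps=1e-3):
    # Central differences, one coordinate at a time.
    grads = []
    n = y.shape[0]
    for i in range(n):
        e = mx.array([eps if j == i else 0.0 for j in range(n)])
        grads.append(((f(y + e) - f(y - e)) / (2 * eps)).item())
    return mx.array(grads)

def fun(y):
    return mx.cumprod(y).sum()

y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
print(mx.grad(fun)(y))         # matches expected [1, 38, 0, 0, 0]
print(numerical_grad(fun, y))  # the same values, up to float32 rounding

The central difference straddles the zero symmetrically, so it recovers the 38 at the zero element that the old formula turned into nan.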