mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Implement diagonal operator (#562)
* Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad <manisharadwad@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -1,6 +1,5 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
| #include <cmath> | ||||
| #include <iostream> // TODO | ||||
| #include <numeric> | ||||
|  | ||||
| #include "doctest/doctest.h" | ||||
| @@ -2634,3 +2633,86 @@ TEST_CASE("test divmod") { | ||||
|   eval(out_holder); | ||||
|   CHECK_EQ(out_holder[0].item<float>(), 1.0); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test diagonal") { | ||||
|   auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2}); | ||||
|   auto out = diagonal(x); | ||||
|   CHECK(array_equal(out, array({0, 3}, {2})).item<bool>()); | ||||
|  | ||||
|   CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range); | ||||
|   CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range); | ||||
|  | ||||
|   x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4}); | ||||
|   out = diagonal(x, 2, 1, 0); | ||||
|   CHECK(array_equal(out, array({8}, {1})).item<bool>()); | ||||
|  | ||||
|   out = diagonal(x, -1, 0, 1); | ||||
|   CHECK(array_equal(out, array({4, 9}, {2})).item<bool>()); | ||||
|  | ||||
|   out = diagonal(x, -5, 0, 1); | ||||
|   eval(out); | ||||
|   CHECK_EQ(out.shape(), std::vector<int>{0}); | ||||
|  | ||||
|   x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2}); | ||||
|   out = diagonal(x, 1, 0, 1); | ||||
|   CHECK(array_equal(out, array({2, 3}, {2, 1})).item<bool>()); | ||||
|  | ||||
|   out = diagonal(x, 0, 2, 0); | ||||
|   CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item<bool>()); | ||||
|  | ||||
|   out = diagonal(x, 1, -1, 0); | ||||
|   CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item<bool>()); | ||||
|  | ||||
|   x = reshape(arange(16), {2, 2, 2, 2}); | ||||
|   out = diagonal(x, 0, 0, 1); | ||||
|   CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2})) | ||||
|             .item<bool>()); | ||||
|  | ||||
|   CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument); | ||||
|  | ||||
|   x = array({0, 1}, {2}); | ||||
|   CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test diag") { | ||||
|   // To few or too many dimensions | ||||
|   CHECK_THROWS(diag(array(0.0))); | ||||
|   CHECK_THROWS(diag(array({0.0}, {1, 1, 1}))); | ||||
|  | ||||
|   // Test with 1D array | ||||
|   auto x = array({0, 1, 2, 3}, {4}); | ||||
|   auto out = diag(x, 0); | ||||
|   CHECK( | ||||
|       array_equal( | ||||
|           out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4})) | ||||
|           .item<bool>()); | ||||
|  | ||||
|   out = diag(x, 1); | ||||
|   CHECK(array_equal( | ||||
|             out, | ||||
|             array( | ||||
|                 {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, | ||||
|                  2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0}, | ||||
|                 {5, 5})) | ||||
|             .item<bool>()); | ||||
|  | ||||
|   out = diag(x, -1); | ||||
|   CHECK(array_equal( | ||||
|             out, | ||||
|             array( | ||||
|                 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, | ||||
|                  0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0}, | ||||
|                 {5, 5})) | ||||
|             .item<bool>()); | ||||
|  | ||||
|   // Test with 2D array | ||||
|   x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3}); | ||||
|   out = diag(x, 0); | ||||
|   CHECK(array_equal(out, array({0, 4, 8}, {3})).item<bool>()); | ||||
|  | ||||
|   out = diag(x, 1); | ||||
|   CHECK(array_equal(out, array({1, 5}, {2})).item<bool>()); | ||||
|  | ||||
|   out = diag(x, -1); | ||||
|   CHECK(array_equal(out, array({3, 7}, {2})).item<bool>()); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jacket
					Jacket