Implement diagonal operator (#562)

* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket
2024-01-30 11:45:48 -06:00
committed by GitHub
parent 65d0b8df9f
commit 3f7aba8498
8 changed files with 309 additions and 4 deletions

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
#include <iostream> // TODO
#include <numeric>
#include "doctest/doctest.h"
@@ -2634,3 +2633,86 @@ TEST_CASE("test divmod") {
eval(out_holder);
CHECK_EQ(out_holder[0].item<float>(), 1.0);
}
TEST_CASE("test diagonal") {
auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});
auto out = diagonal(x);
CHECK(array_equal(out, array({0, 3}, {2})).item<bool>());
CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range);
CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range);
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4});
out = diagonal(x, 2, 1, 0);
CHECK(array_equal(out, array({8}, {1})).item<bool>());
out = diagonal(x, -1, 0, 1);
CHECK(array_equal(out, array({4, 9}, {2})).item<bool>());
out = diagonal(x, -5, 0, 1);
eval(out);
CHECK_EQ(out.shape(), std::vector<int>{0});
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});
out = diagonal(x, 1, 0, 1);
CHECK(array_equal(out, array({2, 3}, {2, 1})).item<bool>());
out = diagonal(x, 0, 2, 0);
CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item<bool>());
out = diagonal(x, 1, -1, 0);
CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item<bool>());
x = reshape(arange(16), {2, 2, 2, 2});
out = diagonal(x, 0, 0, 1);
CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2}))
.item<bool>());
CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument);
x = array({0, 1}, {2});
CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument);
}
TEST_CASE("test diag") {
// To few or too many dimensions
CHECK_THROWS(diag(array(0.0)));
CHECK_THROWS(diag(array({0.0}, {1, 1, 1})));
// Test with 1D array
auto x = array({0, 1, 2, 3}, {4});
auto out = diag(x, 0);
CHECK(
array_equal(
out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4}))
.item<bool>());
out = diag(x, 1);
CHECK(array_equal(
out,
array(
{0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0},
{5, 5}))
.item<bool>());
out = diag(x, -1);
CHECK(array_equal(
out,
array(
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0},
{5, 5}))
.item<bool>());
// Test with 2D array
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3});
out = diag(x, 0);
CHECK(array_equal(out, array({0, 4, 8}, {3})).item<bool>());
out = diag(x, 1);
CHECK(array_equal(out, array({1, 5}, {2})).item<bool>());
out = diag(x, -1);
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
}