mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Implement diagonal operator (#562)
* Implement diagonal operator This implements mx.diagonal in operator level, inspired by @ManishAradwad. * added `mx.diag` with tests * corrected few things * nits in bindings * updates to diag --------- Co-authored-by: ManishAradwad <manisharadwad@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <cmath>
|
||||
#include <iostream> // TODO
|
||||
#include <numeric>
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
@@ -2634,3 +2633,86 @@ TEST_CASE("test divmod") {
|
||||
eval(out_holder);
|
||||
CHECK_EQ(out_holder[0].item<float>(), 1.0);
|
||||
}
|
||||
|
||||
TEST_CASE("test diagonal") {
|
||||
auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});
|
||||
auto out = diagonal(x);
|
||||
CHECK(array_equal(out, array({0, 3}, {2})).item<bool>());
|
||||
|
||||
CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range);
|
||||
CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range);
|
||||
|
||||
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4});
|
||||
out = diagonal(x, 2, 1, 0);
|
||||
CHECK(array_equal(out, array({8}, {1})).item<bool>());
|
||||
|
||||
out = diagonal(x, -1, 0, 1);
|
||||
CHECK(array_equal(out, array({4, 9}, {2})).item<bool>());
|
||||
|
||||
out = diagonal(x, -5, 0, 1);
|
||||
eval(out);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{0});
|
||||
|
||||
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});
|
||||
out = diagonal(x, 1, 0, 1);
|
||||
CHECK(array_equal(out, array({2, 3}, {2, 1})).item<bool>());
|
||||
|
||||
out = diagonal(x, 0, 2, 0);
|
||||
CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item<bool>());
|
||||
|
||||
out = diagonal(x, 1, -1, 0);
|
||||
CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item<bool>());
|
||||
|
||||
x = reshape(arange(16), {2, 2, 2, 2});
|
||||
out = diagonal(x, 0, 0, 1);
|
||||
CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument);
|
||||
|
||||
x = array({0, 1}, {2});
|
||||
CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("test diag") {
|
||||
// To few or too many dimensions
|
||||
CHECK_THROWS(diag(array(0.0)));
|
||||
CHECK_THROWS(diag(array({0.0}, {1, 1, 1})));
|
||||
|
||||
// Test with 1D array
|
||||
auto x = array({0, 1, 2, 3}, {4});
|
||||
auto out = diag(x, 0);
|
||||
CHECK(
|
||||
array_equal(
|
||||
out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4}))
|
||||
.item<bool>());
|
||||
|
||||
out = diag(x, 1);
|
||||
CHECK(array_equal(
|
||||
out,
|
||||
array(
|
||||
{0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||
2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0},
|
||||
{5, 5}))
|
||||
.item<bool>());
|
||||
|
||||
out = diag(x, -1);
|
||||
CHECK(array_equal(
|
||||
out,
|
||||
array(
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
|
||||
0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0},
|
||||
{5, 5}))
|
||||
.item<bool>());
|
||||
|
||||
// Test with 2D array
|
||||
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3});
|
||||
out = diag(x, 0);
|
||||
CHECK(array_equal(out, array({0, 4, 8}, {3})).item<bool>());
|
||||
|
||||
out = diag(x, 1);
|
||||
CHECK(array_equal(out, array({1, 5}, {2})).item<bool>());
|
||||
|
||||
out = diag(x, -1);
|
||||
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user