add mx.trace (#1143) (#1147)

* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
This commit is contained in:
Abe Leininger
2024-05-22 18:50:27 -04:00
committed by GitHub
parent e110ca11e2
commit 79ef49b2c2
6 changed files with 161 additions and 0 deletions

View File

@@ -3327,3 +3327,20 @@ TEST_CASE("test conv1d") {
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
}
}
TEST_CASE("test trace") {
auto in = eye(3);
auto out = trace(in).item<float>();
CHECK_EQ(out, 3.0);
in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
auto out2 = trace(in).item<int>();
CHECK_EQ(out2, 15);
in = reshape(arange(8), {2, 2, 2});
auto out3 = trace(in, 0, 0, 1);
CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());
auto out4 = trace(in, 0, 1, 2, float32);
CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
}