mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
* working c++ trace implementation * updated throw + added overloads * added python binding for trace function * pre-commit reformatting * add trace to docs * resolve comments * remove to_stream call
This commit is contained in:
@@ -3327,3 +3327,20 @@ TEST_CASE("test conv1d") {
|
||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test trace") {
|
||||
auto in = eye(3);
|
||||
auto out = trace(in).item<float>();
|
||||
CHECK_EQ(out, 3.0);
|
||||
|
||||
in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
|
||||
auto out2 = trace(in).item<int>();
|
||||
CHECK_EQ(out2, 15);
|
||||
|
||||
in = reshape(arange(8), {2, 2, 2});
|
||||
auto out3 = trace(in, 0, 0, 1);
|
||||
CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());
|
||||
|
||||
auto out4 = trace(in, 0, 1, 2, float32);
|
||||
CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user