mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 22:28:11 +08:00
Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad * zeros for arg reduce
This commit is contained in:
@@ -1983,6 +1983,12 @@ TEST_CASE("test take") {
|
||||
CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());
|
||||
out = take(a, array({0, 1}), 1);
|
||||
CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());
|
||||
|
||||
// Indices have wrong shape
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
|
||||
CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
|
||||
CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
|
||||
}
|
||||
|
||||
TEST_CASE("test take along axis") {
|
||||
@@ -2001,12 +2007,6 @@ TEST_CASE("test take along axis") {
|
||||
out = take_along_axis(a, array({1}), -1);
|
||||
CHECK_EQ(out.item<int>(), 1);
|
||||
|
||||
// Indices have wrong shape
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
|
||||
CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
|
||||
CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
|
||||
|
||||
// Empty arrays
|
||||
a = reshape(array({}), {1, 0});
|
||||
CHECK_THROWS(take_along_axis(a, array({1}), 0));
|
||||
@@ -2057,6 +2057,48 @@ TEST_CASE("test take along axis") {
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test put along axis") {
|
||||
// No zero dim arrays
|
||||
auto a = array(1);
|
||||
auto v = array(1);
|
||||
CHECK_THROWS(put_along_axis(a, array(0), v, 0));
|
||||
|
||||
// Index and array size mismatches
|
||||
a = arange(5);
|
||||
CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 1));
|
||||
CHECK_THROWS(put_along_axis(a, array({1}, {1, 1}), array({0}), 0));
|
||||
CHECK_THROWS(put_along_axis(a, array(1), array(0), -1));
|
||||
|
||||
auto expected = array({0, 0, 2, 3, 4});
|
||||
auto out = put_along_axis(a, array({1}), array({0}), 0);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
// Empty arrays
|
||||
a = reshape(array({}), {1, 0});
|
||||
CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 0));
|
||||
|
||||
auto inds = reshape(astype(array({}), int32), {1, 0});
|
||||
out = take_along_axis(a, inds, 0);
|
||||
eval(out); // Make sure it runs
|
||||
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
|
||||
|
||||
a = array({1, 2, 3, 4}, {2, 2});
|
||||
inds = array({0, 1}, {1, 2});
|
||||
out = put_along_axis(a, inds, array({0}), 0);
|
||||
expected = array({0, 2, 3, 0}, {2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
inds = array({0, 0, 1, 1}, {2, 2}, int32);
|
||||
auto values = array({2, 3, 4, 5}, {2, 2}, int32);
|
||||
out = put_along_axis(a, inds, values, 0);
|
||||
CHECK(array_equal(out, array({2, 3, 4, 5}, {2, 2})).item<bool>());
|
||||
|
||||
inds = array({0, 1}, {2, 1});
|
||||
out = put_along_axis(a, inds, array({0}), 1);
|
||||
expected = array({0, 2, 3, 0}, {2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test scatter") {
|
||||
// More indices than dimensions
|
||||
CHECK_THROWS(scatter(array(0), array({1}), array(1), 0));
|
||||
|
||||
Reference in New Issue
Block a user