mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@@ -205,6 +204,46 @@ TEST_CASE("test split") {
|
||||
CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test swap and move axes") {
|
||||
// Test swapaxes
|
||||
array a(0.0);
|
||||
CHECK_THROWS(swapaxes(a, 0, 0));
|
||||
|
||||
a = zeros({2});
|
||||
CHECK_THROWS(swapaxes(a, 0, 1));
|
||||
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
|
||||
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
|
||||
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(swapaxes(a, 0, -4));
|
||||
CHECK_THROWS(swapaxes(a, 0, 3));
|
||||
CHECK_THROWS(swapaxes(a, 3, 0));
|
||||
CHECK_THROWS(swapaxes(a, -4, 0));
|
||||
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
|
||||
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
|
||||
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
|
||||
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
|
||||
|
||||
// Test moveaxis
|
||||
a = array(0.0);
|
||||
CHECK_THROWS(moveaxis(a, 0, 0));
|
||||
|
||||
a = zeros({2});
|
||||
CHECK_THROWS(moveaxis(a, 0, 1));
|
||||
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
|
||||
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
|
||||
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(moveaxis(a, 0, -4));
|
||||
CHECK_THROWS(moveaxis(a, 0, 3));
|
||||
CHECK_THROWS(moveaxis(a, 3, 0));
|
||||
CHECK_THROWS(moveaxis(a, -4, 0));
|
||||
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
|
||||
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
|
||||
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
|
||||
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
|
||||
}
|
||||
|
||||
TEST_CASE("test transpose") {
|
||||
array x(1);
|
||||
auto y = transpose(x);
|
||||
@@ -2003,4 +2042,4 @@ TEST_CASE("test eye with negative k offset") {
|
||||
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{4, 3});
|
||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user