mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@@ -205,6 +204,46 @@ TEST_CASE("test split") {
|
||||
CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test swap and move axes") {
|
||||
// Test swapaxes
|
||||
array a(0.0);
|
||||
CHECK_THROWS(swapaxes(a, 0, 0));
|
||||
|
||||
a = zeros({2});
|
||||
CHECK_THROWS(swapaxes(a, 0, 1));
|
||||
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
|
||||
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
|
||||
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(swapaxes(a, 0, -4));
|
||||
CHECK_THROWS(swapaxes(a, 0, 3));
|
||||
CHECK_THROWS(swapaxes(a, 3, 0));
|
||||
CHECK_THROWS(swapaxes(a, -4, 0));
|
||||
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
|
||||
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
|
||||
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
|
||||
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
|
||||
|
||||
// Test moveaxis
|
||||
a = array(0.0);
|
||||
CHECK_THROWS(moveaxis(a, 0, 0));
|
||||
|
||||
a = zeros({2});
|
||||
CHECK_THROWS(moveaxis(a, 0, 1));
|
||||
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
|
||||
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
|
||||
|
||||
a = zeros({2, 3, 4});
|
||||
CHECK_THROWS(moveaxis(a, 0, -4));
|
||||
CHECK_THROWS(moveaxis(a, 0, 3));
|
||||
CHECK_THROWS(moveaxis(a, 3, 0));
|
||||
CHECK_THROWS(moveaxis(a, -4, 0));
|
||||
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
|
||||
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
|
||||
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
|
||||
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
|
||||
}
|
||||
|
||||
TEST_CASE("test transpose") {
|
||||
array x(1);
|
||||
auto y = transpose(x);
|
||||
@@ -2003,4 +2042,4 @@ TEST_CASE("test eye with negative k offset") {
|
||||
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
|
||||
{4, 3});
|
||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||
}
|
||||
}
|
||||
|
@@ -248,3 +248,104 @@ TEST_CASE("test vmap creation ops") {
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap slice") {
|
||||
{
|
||||
auto fun = [](array in) { return slice(in, {4}, {8}, {2}); };
|
||||
auto x = reshape(arange(16), {2, 8});
|
||||
auto out = vmap(fun)(x);
|
||||
auto expected = reshape(array({4, 6, 12, 14}), {2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](array in) { return slice(in, {0, 1}, {2, 3}); };
|
||||
auto x = reshape(arange(12), {2, 2, 3});
|
||||
auto out = vmap(fun, 1, 0)(x);
|
||||
auto expected = reshape(array({1, 2, 7, 8, 4, 5, 10, 11}), {2, 2, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap concatenate") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{concatenate(inputs, 0)};
|
||||
};
|
||||
auto x = reshape(arange(4), {2, 2});
|
||||
auto y = reshape(arange(4), {2, 2});
|
||||
auto out = vmap(fun)({x, y})[0];
|
||||
auto expected = reshape(array({0, 1, 0, 1, 2, 3, 2, 3}), {2, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
out = vmap(fun, {1, 1})({x, y})[0];
|
||||
expected = reshape(array({0, 2, 0, 2, 1, 3, 1, 3}), {2, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
out = vmap(fun, {0, 1})({x, y})[0];
|
||||
expected = reshape(array({0, 1, 0, 2, 2, 3, 1, 3}), {2, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap gather") {
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {0, -1})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 2, 3, 2, 2});
|
||||
out = vmap(fun, {0, -1}, {3})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {0, 0})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
|
||||
auto out = vmap(fun, {-1, 0})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
|
||||
std::vector<int> slice_sizes = {1, 1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto z = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
|
||||
z = array({0, 1, 0, 0, 1, 0}, {3, 2});
|
||||
out = vmap(fun, {-1, 0, 1})({x, y, z})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user