mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -395,7 +395,7 @@ TEST_CASE("test split") {
|
||||
CHECK_EQ(out[1].shape(), Shape{8, 4});
|
||||
CHECK_EQ(out[2].shape(), Shape{8, 4});
|
||||
|
||||
out = split(x, std::vector<int>{});
|
||||
out = split(x, Shape{});
|
||||
CHECK_EQ(out.size(), 1);
|
||||
CHECK_EQ(out[0].shape(), x.shape());
|
||||
|
||||
@@ -405,25 +405,25 @@ TEST_CASE("test split") {
|
||||
CHECK_EQ(out[1].shape(), Shape{4, 12});
|
||||
CHECK_EQ(out[2].shape(), Shape{1, 12});
|
||||
|
||||
out = split(x, std::vector<int>{20});
|
||||
out = split(x, Shape{20});
|
||||
CHECK_EQ(out.size(), 2);
|
||||
CHECK_EQ(out[0].shape(), Shape{8, 12});
|
||||
CHECK_EQ(out[1].shape(), Shape{0, 12});
|
||||
|
||||
// Negative indices
|
||||
out = split(x, std::vector<int>{-5});
|
||||
out = split(x, Shape{-5});
|
||||
CHECK_EQ(out[0].shape(), Shape{3, 12});
|
||||
CHECK_EQ(out[1].shape(), Shape{5, 12});
|
||||
|
||||
// Different axis
|
||||
out = split(x, std::vector<int>{2, 8}, 1);
|
||||
out = split(x, {2, 8}, 1);
|
||||
CHECK_EQ(out[0].shape(), Shape{8, 2});
|
||||
CHECK_EQ(out[1].shape(), Shape{8, 6});
|
||||
CHECK_EQ(out[2].shape(), Shape{8, 4});
|
||||
|
||||
// Out of order indices
|
||||
x = arange(5);
|
||||
out = split(x, std::vector<int>{2, 1, 2});
|
||||
out = split(x, {2, 1, 2});
|
||||
CHECK(array_equal(out[0], array({0, 1})).item<bool>());
|
||||
CHECK(array_equal(out[1], array({})).item<bool>());
|
||||
CHECK(array_equal(out[2], array({1})).item<bool>());
|
||||
|
||||
@@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
|
||||
CHECK_THROWS(categorical(logits, -3));
|
||||
|
||||
// Invalid requested shapes
|
||||
CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
|
||||
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
|
||||
CHECK_THROWS(categorical(logits, 1, Shape{1}));
|
||||
CHECK_THROWS(categorical(logits, 1, Shape{11}));
|
||||
CHECK_THROWS(categorical(logits, 1, {10, 1}));
|
||||
|
||||
CHECK_EQ(categorical(logits, -1).shape(), Shape{10});
|
||||
|
||||
@@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
|
||||
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
@@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
|
||||
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
@@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = inputs[1];
|
||||
std::vector<int> slice_sizes = {1, 2, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
|
||||
auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
@@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
auto src = inputs[0];
|
||||
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
|
||||
std::vector<int> slice_sizes = {1, 1, 2, 2};
|
||||
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
|
||||
auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
|
||||
Reference in New Issue
Block a user