mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides * Fix uniform buffer indices or copy kernel arguments * Update utils.h * Remove manual unrolling of elem to loc loop * GPU copy updated to handle negative strides * Add slice update primitive
This commit is contained in:
@@ -198,6 +198,31 @@ TEST_CASE("test slice") {
|
||||
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test slice update") {
|
||||
array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32);
|
||||
array y = array(
|
||||
{
|
||||
1.,
|
||||
2.,
|
||||
3.,
|
||||
4.,
|
||||
},
|
||||
{4},
|
||||
float32);
|
||||
|
||||
auto out = slice_update(x, y, {2}, {6}, {1});
|
||||
CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());
|
||||
|
||||
out = slice_update(x, y, {5}, {1}, {-1});
|
||||
CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());
|
||||
|
||||
x = reshape(x, {2, 4});
|
||||
out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
|
||||
out = reshape(out, {8});
|
||||
CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());
|
||||
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test split") {
|
||||
array x = array(1);
|
||||
CHECK_THROWS(split(x, 0));
|
||||
|
Reference in New Issue
Block a user