Add a SliceUpdate op and primitive (#850)

* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
This commit is contained in:
Jagrit Digani
2024-03-20 10:39:25 -07:00
committed by GitHub
parent 73a8c090e0
commit cec8661113
21 changed files with 1147 additions and 506 deletions

View File

@@ -198,6 +198,31 @@ TEST_CASE("test slice") {
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
}
TEST_CASE("test slice update") {
array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32);
array y = array(
{
1.,
2.,
3.,
4.,
},
{4},
float32);
auto out = slice_update(x, y, {2}, {6}, {1});
CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());
out = slice_update(x, y, {5}, {1}, {-1});
CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());
x = reshape(x, {2, 4});
out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
out = reshape(out, {8});
CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
}
TEST_CASE("test split") {
array x = array(1);
CHECK_THROWS(split(x, 0));