Add a SliceUpdate op and primitive (#850)

* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
This commit is contained in:
Jagrit Digani
2024-03-20 10:39:25 -07:00
committed by GitHub
parent 73a8c090e0
commit cec8661113
21 changed files with 1147 additions and 506 deletions

View File

@@ -206,7 +206,7 @@ inline auto collapse_batches(const array& a, const array& b) {
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
@@ -237,8 +237,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];