Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -161,3 +161,78 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
idx.y += dst_xstride;
}
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint3 index [[thread_position_in_grid]]) {
src += src_offset;
dst += dst_offset;
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
dst_strides,
ndim);
if (N == 1) {
dst[idx.y] = src[idx.x];
return;
}
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = src[idx.x];
idx.x += src_xstride;
idx.y += dst_xstride;
}
}