Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -22,7 +22,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
@@ -32,7 +32,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
@@ -42,7 +42,7 @@ template <typename T, typename U, typename IdxT = int64_t>
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
@@ -53,7 +53,7 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -65,7 +65,7 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
@@ -80,7 +80,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc<int64_t, IdxT>(
auto src_idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
IdxT dst_idx =
@@ -104,8 +104,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -116,8 +116,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -128,8 +128,8 @@ template <typename T, typename U, typename IdxT = int64_t>
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -142,7 +142,7 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,