fix large ops (#1620)

This commit is contained in:
Awni Hannun
2024-11-24 09:17:10 -08:00
committed by GitHub
parent bb303c45a5
commit 211411faf2
12 changed files with 37 additions and 25 deletions

View File

@@ -36,13 +36,13 @@ template <typename T, typename U>
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
@@ -97,15 +97,15 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
}
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}