Fixes for large arrays with a few ops (#1299)

* fixes for large arrays with a few ops

* fix bug

* fix all of copy
This commit is contained in:
Awni Hannun
2024-07-30 17:18:39 -07:00
committed by GitHub
parent c52d1600f0
commit 40b6d67333
21 changed files with 273 additions and 202 deletions

View File

@@ -16,6 +16,26 @@ template <typename T, typename U>
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],