Fix a few ops for large arrays (#1299)

* fix a few ops for large arrays

* fix bug

* fix all of copy
This commit is contained in:
Awni Hannun
2024-07-30 17:18:39 -07:00
committed by GitHub
parent c52d1600f0
commit 40b6d67333
21 changed files with 273 additions and 202 deletions

View File

@@ -25,7 +25,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
@@ -83,7 +83,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
@@ -107,7 +107,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
@@ -170,7 +170,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
out += gid * size_t(axis_size);
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;