mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-24 12:18:20 +08:00
Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops * fix bug * fix all of copy
This commit is contained in:
@@ -25,7 +25,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * axis_size + lid * N_READS;
|
||||
in += gid * size_t(axis_size) + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
@@ -83,7 +83,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
normalizer = 1 / local_normalizer[0];
|
||||
|
||||
// Normalize and write to the output
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
out += gid * size_t(axis_size) + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = T(ld[i] * normalizer);
|
||||
@@ -107,7 +107,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * axis_size;
|
||||
in += gid * size_t(axis_size);
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
@@ -170,7 +170,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
|
||||
// Finally given the normalizer and max value we can directly write the
|
||||
// softmax output
|
||||
out += gid * axis_size;
|
||||
out += gid * size_t(axis_size);
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
|
||||
Reference in New Issue
Block a user