commit 5e5d379a2d
Author: Awni Hannun
Date:   2025-06-17 00:06:01 +09:00

3 changed files with 42 additions and 10 deletions


@@ -5,28 +5,33 @@ template <typename T, typename AccT = float, int N_READS = 4>
     const device T* in,
     device T* out,
     constant int& axis_size,
-    uint gid [[threadgroup_position_in_grid]],
-    uint _lid [[thread_position_in_threadgroup]],
+    uint2 gid [[threadgroup_position_in_grid]],
+    uint2 tid [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]],
+    uint2 _lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  int lid = _lid;
+  int lid = _lid.x;

   constexpr int SIMD_SIZE = 32;
+  constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS;

   threadgroup AccT local_max[SIMD_SIZE];
   threadgroup AccT local_normalizer[SIMD_SIZE];

   AccT ld[N_READS];

-  in += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  const int axis_offset = tid.y * elem_per_group;
+  in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset;
+  if (axis_offset + lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       ld[i] = AccT(in[i]);
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] =
-          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
+      ld[i] = ((axis_offset + lid * N_READS + i) < axis_size)
+          ? AccT(in[i])
+          : Limits<AccT>::min;
     }
   }
   if (simd_group_id == 0) {
@@ -55,6 +60,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
     maxval = local_max[0];

   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
+  out += gid.x * grid_dim.y + tid.y;
   AccT normalizer = 0;
   for (int i = 0; i < N_READS; i++) {
     normalizer += fast::exp(ld[i] - maxval);
@@ -67,7 +73,7 @@ template <typename T, typename AccT = float, int N_READS = 4>
   if (simd_group_id == 0) {
     normalizer = simd_sum(local_normalizer[simd_lane_id]);
     if (simd_lane_id == 0) {
-      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
+      out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
     }
   }
 }
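Taken together, these kernel changes let several threadgroups cooperate on one logical row: tid.y selects a chunk of elem_per_group = 32 * 32 * N_READS elements, out-of-range reads are padded with Limits<AccT>::min so they cannot affect the max, and each group writes one partial logsumexp to its own output slot. This is sound because logsumexp composes over any partition of the input. A minimal NumPy sketch of that identity (illustration only, not the Metal code; ref_logsumexp and split_logsumexp are names chosen here):

```python
import numpy as np

def ref_logsumexp(x):
    # Numerically stable single-pass reference.
    m = np.max(x)
    if np.isinf(m):
        return m
    return np.log(np.sum(np.exp(x - m))) + m

def split_logsumexp(x, chunk):
    # Pass 1: reduce each chunk to a partial logsumexp (one "threadgroup" each).
    partials = np.array(
        [ref_logsumexp(x[i : i + chunk]) for i in range(0, len(x), chunk)]
    )
    # Pass 2: run the same reduction over the partials.
    return ref_logsumexp(partials)

x = np.random.uniform(size=4 * 4096 + 3)
assert np.isclose(split_logsumexp(x, 4096), ref_logsumexp(x))
```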


@@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   const int n_reads = 4;
   const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
-  std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
+  bool split = n_rows < 4 && axis_size > 4 * looped_limit;
+  bool looped = !split && axis_size > looped_limit;
+  std::string kernel_name = looped ? "looped_" : "block_";
   kernel_name += "logsumexp_";
   kernel_name += type_to_name(out);

   auto kernel = get_logsumexp_kernel(d, kernel_name, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
+  if (split) {
+    auto tmp_size = ceildiv(axis_size, looped_limit);
+    auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)};
+    array tmp(tmp_shape, in.dtype(), nullptr, {});
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
+
+    size_t threadgroup_size = 1024;
+    assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+    size_t n_threads = n_rows * threadgroup_size;
+    auto grid_dims = MTL::Size(n_threads, tmp_size, 1);
+    auto group_dims = MTL::Size(threadgroup_size, 1, 1);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(tmp, 1);
+    compute_encoder.set_bytes(axis_size, 2);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+
+    d.add_temporary(tmp, s.index);
+    in = tmp;
+    axis_size = tmp_size;
+  }

   {
     MTL::Size grid_dims, group_dims;
-    if (axis_size <= looped_limit) {
+    if (!looped) {
       size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
       size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
       size_t threadgroup_size = simd_size * simds_needed;
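On the host side, the split branch runs the kernel twice: first into a temporary of shape (n_rows, ceildiv(axis_size, looped_limit)), after which the unchanged code below reduces that temporary as an ordinary row. A rough sketch of the dispatch arithmetic, assuming LOGSUMEXP_LOOPED_LIMIT is 4096 (the constant is defined elsewhere in the tree; the value here is only an assumption, consistent with elem_per_group = 32 * 32 * 4 in the kernel):

```python
def split_dispatch(n_rows, axis_size, looped_limit=4096, threadgroup_size=1024):
    # Each threadgroup covers threadgroup_size * n_reads = 4096 elements, so a
    # row of axis_size elements needs ceildiv(axis_size, looped_limit) groups
    # along grid dimension y; grid dimension x spans the rows.
    tmp_size = (axis_size + looped_limit - 1) // looped_limit  # ceildiv
    grid_dims = (n_rows * threadgroup_size, tmp_size, 1)
    group_dims = (threadgroup_size, 1, 1)
    return tmp_size, grid_dims, group_dims

# One row of 4 * 4096 + 3 elements is reduced to 5 partials in the first pass.
print(split_dispatch(1, 4 * 4096 + 3))  # -> (5, (1024, 5, 1), (1024, 1, 1))
```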


@@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
         x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

+        # Even larger
+        x = mx.random.uniform(shape=(4 * 4096 + 3,))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
     def test_mean(self):
         x = mx.array(
             [
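The new test case appears to target the split path specifically: with a single row, n_rows < 4 holds, and 4 * 4096 + 3 elements exceed 4 * looped_limit under the assumed limit of 4096, while the trailing + 3 exercises the ragged final chunk that the kernel pads with Limits<AccT>::min. A quick check of the condition (again assuming looped_limit = 4096):

```python
n_rows, axis_size, looped_limit = 1, 4 * 4096 + 3, 4096
split = n_rows < 4 and axis_size > 4 * looped_limit
print(split)  # True: this input dispatches the new two-pass reduction
```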