diff --git a/benchmarks/python/layer_norm_bench.py b/benchmarks/python/layer_norm_bench.py index 69263835a..29925de0b 100644 --- a/benchmarks/python/layer_norm_bench.py +++ b/benchmarks/python/layer_norm_bench.py @@ -1,5 +1,7 @@ # Copyright © 2023-2024 Apple Inc. +from functools import partial + import mlx.core as mx import mlx.nn as nn from time_utils import time_fn @@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps): return y -def time_layer_norm(): +def time_layer_norm(N, dt): + L = 1024 f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum() f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0, 1, 2)) g2 = mx.grad(f2, argnums=(0, 1, 2)) - x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) - w = mx.random.uniform(shape=(4096,)).astype(mx.float16) - b = mx.random.uniform(shape=(4096,)).astype(mx.float16) - y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(8, L, N)).astype(dt) + w = mx.random.uniform(shape=(N,)).astype(dt) + b = mx.random.uniform(shape=(N,)).astype(dt) + y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) - def layer_norm_loop(g, x, w, b): + def layer_norm_loop(f, x, w, b): + for _ in range(32): + x = f(x, w, b) + return x + + time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b) + time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b) + + def layer_norm_grad_loop(g, x, w, b): gx, gw, gb = x, w, b for _ in range(32): gx, gw, gb = g(gx, gw, gb, y) return gx, gw, gb - time_fn(layer_norm_loop, g1, x, w, b) - time_fn(layer_norm_loop, g2, x, w, b) - time_fn(layer_norm_loop, mx.compile(g1), x, w, b) - time_fn(layer_norm_loop, mx.compile(g2), x, w, b) + time_fn(layer_norm_grad_loop, g1, x, w, b) + time_fn(layer_norm_grad_loop, g2, x, w, b) + time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b) + time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b) f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0,)) g2 = mx.grad(f2, argnums=(0,)) - x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) - w = mx.random.uniform(shape=(4096,)).astype(mx.float16) - b = mx.random.uniform(shape=(4096,)).astype(mx.float16) - y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(8, L, N)).astype(dt) + w = mx.random.uniform(shape=(N,)).astype(dt) + b = mx.random.uniform(shape=(N,)).astype(dt) + y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) - def layer_norm_loop(g, x): + def layer_norm_grad_x_loop(g, x): gx = x for _ in range(32): gx = g(gx, y) return gx - time_fn(layer_norm_loop, g1, x) - time_fn(layer_norm_loop, g2, x) - time_fn(layer_norm_loop, mx.compile(g1), x) - time_fn(layer_norm_loop, mx.compile(g2), x) + time_fn(layer_norm_grad_x_loop, g1, x) + time_fn(layer_norm_grad_x_loop, g2, x) + time_fn(layer_norm_grad_x_loop, mx.compile(g1), x) + time_fn(layer_norm_grad_x_loop, mx.compile(g2), x) if __name__ == "__main__": - time_layer_norm() + for dt in [mx.float32, mx.float16, mx.bfloat16]: + for n in [1024, 2048, 4096, 8192, 8192 + 1024]: + print(dt, n) + time_layer_norm(n, dt) diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 51570e48d..06b8be55f 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -9,7 +9,41 @@ using namespace metal; constant bool has_w [[function_constant(20)]]; -template +template +inline void initialize_buffer( + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + if (simd_group_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_lane_id + i] = 0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline void threadgroup_sum( + thread float* x, + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + for (int i = 0; i < N; i++) { + x[i] = simd_sum(x[i]); + } + if (simd_lane_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_group_id + i] = x[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N; i++) { + x[i] = xs[N * simd_lane_id + i]; + x[i] = simd_sum(x[i]); + } +} + +template [[kernel]] void layer_norm_single_row( const device T* x, const device T* w, @@ -23,90 +57,71 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - float thread_x[N_READS]; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + threadgroup float local_buffer[SIMD_SIZE] = {0}; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); + // Advance the pointers x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; + + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + normalizer += thread_x[i] * thread_x[i]; + } + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - out[i] = - w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } -template +template [[kernel]] void layer_norm_looped( const device T* x, const device T* w, @@ -121,71 +136,52 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + threadgroup float local_buffer[SIMD_SIZE]; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs out += gid * size_t(axis_size) + lid * N_READS; @@ -208,7 +204,7 @@ template } } -template +template [[kernel]] void vjp_layer_norm_single_row( const device T* x, const device T* w, @@ -222,133 +218,96 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the computation and accumulators - float thread_x[N_READS]; - float thread_w[N_READS]; - float thread_g[N_READS]; - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + float thread_w[N_READS] = {0}; + float thread_g[N_READS] = {0}; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - constexpr int SIMD_SIZE = 32; + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; - - if (lid * N_READS + N_READS <= axis_size) { + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; + thread_w[i] = w[i * w_stride]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; - thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; + thread_g[i] = g[i]; + thread_w[i] = w[i * w_stride]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + factors[meanwg] += thread_w[i] * thread_g[i]; + factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; + factors[normalizer2] += thread_x[i] * thread_x[i]; + } + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } } -template +template [[kernel]] void vjp_layer_norm_looped( const device T* x, const device T* w, @@ -363,102 +322,69 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the accumulators - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; - - constexpr int SIMD_SIZE = 32; - - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; @@ -470,7 +396,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } @@ -482,7 +409,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c0901ccec..c53289828 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -255,12 +255,13 @@ void LayerNorm::eval_gpu( auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; - const int simd_size = 32; - const int n_reads = RMS_N_READS; - const int looped_limit = RMS_LOOPED_LIMIT; + int simd_size = 32; + int n_reads = 8; + int looped_limit = 6656; std::string op_name = "layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; + n_reads = 4; } op_name += type_to_name(out); auto& compute_encoder = d.get_command_encoder(s.index); @@ -272,7 +273,13 @@ void LayerNorm::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); @@ -372,12 +379,13 @@ void LayerNormVJP::eval_gpu( g, gb, "sum", plan, {0}, compute_encoder, d, s); } - const int simd_size = 32; - const int n_reads = RMS_N_READS; - const int looped_limit = RMS_LOOPED_LIMIT; + int simd_size = 32; + int n_reads = 8; + int looped_limit = 8192; std::string op_name = "vjp_layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; + n_reads = 4; } op_name += type_to_name(gx); @@ -394,7 +402,13 @@ void LayerNormVJP::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index aad1a0018..096d6b906 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -369,7 +369,7 @@ bool ScaledDotProductAttention::use_fallback( const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); - const bool supports_sdpa_full = + const bool supports_sdpa_full = query_sequence_length > 8 && sdpa_full_supported_mask && sdpa_full_supported_head_dim; const bool supports_sdpa_vector = (query_sequence_length <= 8) && diff --git a/mlx/fast.cpp b/mlx/fast.cpp index eab22f14d..657c0aba8 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -231,13 +231,11 @@ array layer_norm( const std::vector& inputs) { auto x = astype(inputs[0], float32, s); - // Should I not be smart here and leave the double mean to simplify()? auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s); - auto mu2 = square(mu, s); - auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s); - auto v = subtract(x2, mu2, s); + auto xc = subtract(x, mu, s); + auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s); - x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s)); + x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s)); x = astype(x, out_type, s); // If the LN is affine then transform x according to the weight and bias