mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Change layernorms to two pass algorithm (#2246)

This commit is contained in:
parent 24f89173d1
commit 2e8cf0b450
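In a nutshell: the layer norm kernels (and the graph-level fallback) previously computed the variance in a single pass as E[x²] − E[x]², which cancels catastrophically in low precision when the mean is large relative to the spread. This commit switches to a two-pass algorithm: one reduction for the mean, then a second reduction over squared deviations from it. A minimal NumPy sketch of the failure mode (the shapes and constants here are illustrative, not taken from the kernels):

import numpy as np

x = (1000.0 + 0.1 * np.random.randn(4096)).astype(np.float32)

mean = x.mean(dtype=np.float32)
one_pass = (x * x).mean(dtype=np.float32) - mean * mean  # E[x^2] - E[x]^2
two_pass = ((x - mean) ** 2).mean(dtype=np.float32)      # E[(x - mean)^2]

print(one_pass)  # swamped by rounding error at this magnitude; can even go negative
print(two_pass)  # ~0.01, close to the true variance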
@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
+from functools import partial
+
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
     return y
 
 
-def time_layer_norm():
+def time_layer_norm(N, dt):
+    L = 1024
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))
 
-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)
 
-    def layer_norm_loop(g, x, w, b):
+    def layer_norm_loop(f, x, w, b):
+        for _ in range(32):
+            x = f(x, w, b)
+        return x
+
+    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
+    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
+
+    def layer_norm_grad_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb
 
-    time_fn(layer_norm_loop, g1, x, w, b)
-    time_fn(layer_norm_loop, g2, x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_grad_loop, g1, x, w, b)
+    time_fn(layer_norm_grad_loop, g2, x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
 
     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))
 
-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
    mx.eval(x, w, b, y)
 
-    def layer_norm_loop(g, x):
+    def layer_norm_grad_x_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx
 
-    time_fn(layer_norm_loop, g1, x)
-    time_fn(layer_norm_loop, g2, x)
-    time_fn(layer_norm_loop, mx.compile(g1), x)
-    time_fn(layer_norm_loop, mx.compile(g2), x)
+    time_fn(layer_norm_grad_x_loop, g1, x)
+    time_fn(layer_norm_grad_x_loop, g2, x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
 
 
 if __name__ == "__main__":
-    time_layer_norm()
+    for dt in [mx.float32, mx.float16, mx.bfloat16]:
+        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
+            print(dt, n)
+            time_layer_norm(n, dt)
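The benchmark now takes the feature size N and dtype as parameters and sweeps float32/float16/bfloat16 over sizes up to 9216 (8192 + 1024), past the single-row limits below, so the looped kernels get timed too. A quick way to sanity-check the kernels it times, sketched against a float32 reference (the shapes and the check itself are my own assumptions, not part of the commit):

import mlx.core as mx

x = (1000 + mx.random.normal(shape=(4, 9216))).astype(mx.float16)
w = mx.ones((9216,), dtype=mx.float16)
b = mx.zeros((9216,), dtype=mx.float16)

ref = mx.fast.layer_norm(x.astype(mx.float32), w.astype(mx.float32), b.astype(mx.float32), 1e-5)
out = mx.fast.layer_norm(x, w, b, 1e-5)
print(mx.abs(out.astype(mx.float32) - ref).max())  # should be small (fp16 resolution)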
@@ -9,7 +9,41 @@ using namespace metal;
 
 constant bool has_w [[function_constant(20)]];
 
-template <typename T, int N_READS = RMS_N_READS>
+template <int N = 1>
+inline void initialize_buffer(
+    threadgroup float* xs,
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  if (simd_group_id == 0) {
+    for (int i = 0; i < N; i++) {
+      xs[N * simd_lane_id + i] = 0;
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+}
+
+template <int N = 1>
+inline void threadgroup_sum(
+    thread float* x,
+    threadgroup float* xs,
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  for (int i = 0; i < N; i++) {
+    x[i] = simd_sum(x[i]);
+  }
+  if (simd_lane_id == 0) {
+    for (int i = 0; i < N; i++) {
+      xs[N * simd_group_id + i] = x[i];
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  for (int i = 0; i < N; i++) {
+    x[i] = xs[N * simd_lane_id + i];
+    x[i] = simd_sum(x[i]);
+  }
+}
+
+template <typename T, int N_READS = 8>
 [[kernel]] void layer_norm_single_row(
     const device T* x,
     const device T* w,
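The two helpers above are the reduction core of the rewrite: initialize_buffer zeroes the first 32·N shared-memory slots (so stale slots contribute nothing when fewer than 32 simdgroups are live), and threadgroup_sum reduces N values per thread in two levels: simd_sum within each 32-lane simdgroup, lane 0 publishes the partial, then a second simd_sum over the per-simdgroup partials leaves every thread holding the totals. A NumPy emulation of the idea, under the same at-most-32-simdgroups assumption (illustrative, not the Metal code):

import numpy as np

def threadgroup_sum(x):
    # x holds one value per thread; threadgroup = 32 * n_simdgroups threads
    partials = x.reshape(-1, 32).sum(axis=1)  # first simd_sum, one per simdgroup
    shared = np.zeros(32, dtype=x.dtype)      # zero-initialized threadgroup buffer
    shared[: partials.size] = partials        # lane 0 of each simdgroup writes, then barrier
    return shared.sum()                       # second simd_sum over the 32 slots

x = np.random.rand(256).astype(np.float32)
assert np.isclose(threadgroup_sum(x), x.sum(), rtol=1e-5)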
@@ -23,90 +57,71 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  float sumx = 0;
-  float sumx2 = 0;
-  float thread_x[N_READS];
-
   constexpr int SIMD_SIZE = 32;
 
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
+  // Initialize the registers and threadgroup memory
+  float thread_x[N_READS] = {0};
+  threadgroup float local_buffer[SIMD_SIZE];
+  initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
 
+  // Advance the pointers
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
   b += b_stride * lid * N_READS;
+  out += gid * size_t(axis_size) + lid * N_READS;
 
-  if (lid * N_READS + N_READS <= axis_size) {
+  // Compute some variables for reading writing etc
+  const bool safe = lid * N_READS + N_READS <= axis_size;
+  const int n = axis_size - lid * N_READS;
+
+  // Read the inputs
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
       thread_x[i] = x[i];
-      sumx2 += thread_x[i] * thread_x[i];
-      sumx += thread_x[i];
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = x[i];
-        sumx2 += thread_x[i] * thread_x[i];
-        sumx += thread_x[i];
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] = x[i];
     }
   }
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
+  // Compute the mean
+  float mean = 0;
+  for (int i = 0; i < N_READS; i++) {
+    mean += thread_x[i];
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
 
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
+  // Compute the normalizer
+  float normalizer = 0;
+  if (!safe) {
+    for (int i = n; i < N_READS; i++) {
+      thread_x[i] = mean;
     }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
+  for (int i = 0; i < N_READS; i++) {
+    thread_x[i] -= mean;
+    normalizer += thread_x[i] * thread_x[i];
+  }
+  threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
+  normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
 
   // Write the outputs
-  out += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
-      thread_x[i] = (thread_x[i] - mean) * normalizer;
+      thread_x[i] *= normalizer;
       out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = (thread_x[i] - mean) * normalizer;
-        out[i] =
-            w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] *= normalizer;
+      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
     }
   }
 }
 
-template <typename T, int N_READS = RMS_N_READS>
+template <typename T, int N_READS = 4>
 [[kernel]] void layer_norm_looped(
     const device T* x,
     const device T* w,
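Worth noting in the rewritten single-row kernel above: instead of re-checking bounds in every loop, out-of-range register slots are left at their zero initialization during the mean pass (they add nothing to the sum, which is divided by axis_size), and are then overwritten with the mean itself before the variance pass, so they contribute exactly zero to the sum of squared deviations. A small sketch of why that is safe (values are illustrative):

import numpy as np

row = np.random.rand(4099).astype(np.float32)  # axis size not a multiple of N_READS
mean = row.mean(dtype=np.float32)

padded = np.concatenate([row, np.full(5, mean, dtype=np.float32)])  # pad to 4104
var_padded = ((padded - mean) ** 2).sum() / row.size  # padded slots add 0
var_direct = ((row - mean) ** 2).mean(dtype=np.float32)
assert np.isclose(var_padded, var_direct, rtol=1e-5)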
@@ -121,71 +136,52 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  float sumx = 0;
-  float sumx2 = 0;
-
   constexpr int SIMD_SIZE = 32;
 
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
+  threadgroup float local_buffer[SIMD_SIZE];
+  initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
 
   x += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
   b += b_stride * lid * N_READS;
 
+  // Compute the mean
+  float mean = 0;
   for (uint r = 0; r < axis_size; r += lsize * N_READS) {
     if (r + lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
-        float xi = x[i + r];
-        sumx2 += xi * xi;
-        sumx += xi;
+        mean += x[i + r];
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
         if ((r + lid * N_READS + i) < axis_size) {
-          float xi = x[i + r];
-          sumx2 += xi * xi;
-          sumx += xi;
+          mean += x[i + r];
         }
       }
     }
   }
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
+  // Compute the normalizer
+  float normalizer = 0;
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        float t = x[i + r] - mean;
+        normalizer += t * t;
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          float t = x[i + r] - mean;
+          normalizer += t * t;
+        }
+      }
     }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
+  threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
+  normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
 
   // Write the outputs
   out += gid * size_t(axis_size) + lid * N_READS;
@@ -208,7 +204,7 @@ template <typename T, int N_READS = RMS_N_READS>
     }
   }
 
-template <typename T, int N_READS = RMS_N_READS>
+template <typename T, int N_READS = 8>
 [[kernel]] void vjp_layer_norm_single_row(
     const device T* x,
     const device T* w,
@@ -222,133 +218,96 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
   // Advance the input pointers
   x += gid * size_t(axis_size) + lid * N_READS;
   g += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
 
-  // Allocate registers for the computation and accumulators
-  float thread_x[N_READS];
-  float thread_w[N_READS];
-  float thread_g[N_READS];
-  float sumx = 0;
-  float sumx2 = 0;
-  float sumwg = 0;
-  float sumwgx = 0;
+  // Initialize the registers and threadgroup memory
+  float thread_x[N_READS] = {0};
+  float thread_w[N_READS] = {0};
+  float thread_g[N_READS] = {0};
+  threadgroup float local_buffer[3 * SIMD_SIZE];
+  initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
 
-  constexpr int SIMD_SIZE = 32;
-
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_sumwg[SIMD_SIZE];
-  threadgroup float local_sumwgx[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
-  threadgroup float local_meanwg[1];
-  threadgroup float local_meanwgx[1];
+  // Compute some variables for reading writing etc
+  const bool safe = lid * N_READS + N_READS <= axis_size;
+  const int n = axis_size - lid * N_READS;
 
-  if (lid * N_READS + N_READS <= axis_size) {
+  // Read the inputs
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
       thread_x[i] = x[i];
-      thread_w[i] = w[i * w_stride];
       thread_g[i] = g[i];
-      float wg = thread_w[i] * thread_g[i];
-      sumx += thread_x[i];
-      sumx2 += thread_x[i] * thread_x[i];
-      sumwg += wg;
-      sumwgx += wg * thread_x[i];
+      thread_w[i] = w[i * w_stride];
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = x[i];
-        thread_w[i] = w[i * w_stride];
-        thread_g[i] = g[i];
-        float wg = thread_w[i] * thread_g[i];
-        sumx += thread_x[i];
-        sumx2 += thread_x[i] * thread_x[i];
-        sumwg += wg;
-        sumwgx += wg * thread_x[i];
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] = x[i];
+      thread_g[i] = g[i];
+      thread_w[i] = w[i * w_stride];
     }
   }
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-  sumwg = simd_sum(sumwg);
-  sumwgx = simd_sum(sumwgx);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
-    local_sumwg[simd_lane_id] = 0;
-    local_sumwgx[simd_lane_id] = 0;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-    local_sumwg[simd_group_id] = sumwg;
-    local_sumwgx[simd_group_id] = sumwgx;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    sumwg = simd_sum(local_sumwg[simd_lane_id]);
-    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
-      local_meanwg[0] = sumwg / axis_size;
-      local_meanwgx[0] = sumwgx / axis_size;
-    }
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
-  float meanwg = local_meanwg[0];
-  float meanwgxc = local_meanwgx[0] - meanwg * mean;
-  float normalizer2 = normalizer * normalizer;
+  // Compute the mean
+  float mean = 0;
+  for (int i = 0; i < N_READS; i++) {
+    mean += thread_x[i];
+  }
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
+
+  // Compute the necessary scaling factors using the mean
+  if (!safe) {
+    for (int i = n; i < N_READS; i++) {
+      thread_x[i] = mean;
+    }
+  }
+  float factors[3] = {0};
+  constexpr int meanwg = 0;
+  constexpr int meanwgxc = 1;
+  constexpr int normalizer2 = 2;
+  for (int i = 0; i < N_READS; i++) {
+    thread_x[i] -= mean;
+    factors[meanwg] += thread_w[i] * thread_g[i];
+    factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i];
+    factors[normalizer2] += thread_x[i] * thread_x[i];
+  }
+  threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
+  factors[meanwg] /= axis_size;
+  factors[meanwgxc] /= axis_size;
+  factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
+  float normalizer = metal::precise::sqrt(factors[normalizer2]);
 
   // Write the outputs
   gx += gid * size_t(axis_size) + lid * N_READS;
   gw += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  if (safe) {
     for (int i = 0; i < N_READS; i++) {
-      thread_x[i] = (thread_x[i] - mean) * normalizer;
+      thread_x[i] *= normalizer;
       gx[i] = static_cast<T>(
-          normalizer * (thread_w[i] * thread_g[i] - meanwg) -
-          thread_x[i] * meanwgxc * normalizer2);
+          normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
+          thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
       if (has_w) {
         gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
       }
     }
   } else {
-    for (int i = 0; i < N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        thread_x[i] = (thread_x[i] - mean) * normalizer;
-        gx[i] = static_cast<T>(
-            normalizer * (thread_w[i] * thread_g[i] - meanwg) -
-            thread_x[i] * meanwgxc * normalizer2);
-        if (has_w) {
-          gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
-        }
-      }
+    for (int i = 0; i < n; i++) {
+      thread_x[i] *= normalizer;
+      gx[i] = static_cast<T>(
+          normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
+          thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      }
     }
   }
 }
 
-template <typename T, int N_READS = RMS_N_READS>
+template <typename T, int N_READS = 4>
 [[kernel]] void vjp_layer_norm_looped(
     const device T* x,
     const device T* w,
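For reference, the three factors accumulated above implement the usual layer norm VJP. In the notation below (mine, not the source's), \(\hat{x} = (x - \mu)\,r\) with \(r = (v + \epsilon)^{-1/2}\), \(g\) the upstream gradient and \(w\) the scale:

\[
\frac{\partial L}{\partial x}
  = r\left(w \odot g - \operatorname{mean}(w \odot g)
    - \hat{x} \odot \operatorname{mean}(w \odot g \odot \hat{x})\right),
\qquad
\frac{\partial L}{\partial w} \text{ accumulates } g \odot \hat{x}.
\]

In the code, factors[meanwg] = mean(w⊙g), factors[meanwgxc] = mean(w⊙g⊙(x − μ)) and factors[normalizer2] = r², so normalizer·(w⊙g − factors[meanwg]) − x̂·factors[meanwgxc]·factors[normalizer2] expands to exactly the expression above, since mean(w⊙g⊙x̂) = factors[meanwgxc]·r.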
@@ -363,102 +322,69 @@ template <typename T, int N_READS = RMS_N_READS>
     uint lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int SIMD_SIZE = 32;
+
   // Advance the input pointers
   x += gid * size_t(axis_size) + lid * N_READS;
   g += gid * size_t(axis_size) + lid * N_READS;
   w += w_stride * lid * N_READS;
 
-  // Allocate registers for the accumulators
-  float sumx = 0;
-  float sumx2 = 0;
-  float sumwg = 0;
-  float sumwgx = 0;
-
-  constexpr int SIMD_SIZE = 32;
-
-  threadgroup float local_sumx[SIMD_SIZE];
-  threadgroup float local_sumx2[SIMD_SIZE];
-  threadgroup float local_sumwg[SIMD_SIZE];
-  threadgroup float local_sumwgx[SIMD_SIZE];
-  threadgroup float local_mean[1];
-  threadgroup float local_normalizer[1];
-  threadgroup float local_meanwg[1];
-  threadgroup float local_meanwgx[1];
+  threadgroup float local_buffer[3 * SIMD_SIZE];
+  initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
 
+  // Compute the mean
+  float mean = 0;
   for (uint r = 0; r < axis_size; r += lsize * N_READS) {
     if (r + lid * N_READS + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
-        float xi = x[i + r];
-        float wi = w[(i + r) * w_stride];
-        float gi = g[i + r];
-        float wg = wi * gi;
-        sumx += xi;
-        sumx2 += xi * xi;
-        sumwg += wg;
-        sumwgx += wg * xi;
+        mean += x[i + r];
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
         if ((r + lid * N_READS + i) < axis_size) {
-          float xi = x[i + r];
-          float wi = w[(i + r) * w_stride];
-          float gi = g[i + r];
-          float wg = wi * gi;
-          sumx += xi;
-          sumx2 += xi * xi;
-          sumwg += wg;
-          sumwgx += wg * xi;
+          mean += x[i + r];
         }
       }
     }
   }
+  threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
+  mean /= axis_size;
 
-  sumx = simd_sum(sumx);
-  sumx2 = simd_sum(sumx2);
-  sumwg = simd_sum(sumwg);
-  sumwgx = simd_sum(sumwgx);
-
-  // Initialize shared memory
-  if (simd_group_id == 0) {
-    local_sumx[simd_lane_id] = 0;
-    local_sumx2[simd_lane_id] = 0;
-    local_sumwg[simd_lane_id] = 0;
-    local_sumwgx[simd_lane_id] = 0;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write simd accumulations into shared memory
-  if (simd_lane_id == 0) {
-    local_sumx[simd_group_id] = sumx;
-    local_sumx2[simd_group_id] = sumx2;
-    local_sumwg[simd_group_id] = sumwg;
-    local_sumwgx[simd_group_id] = sumwgx;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Accumulate over simd groups
-  if (simd_group_id == 0) {
-    sumx = simd_sum(local_sumx[simd_lane_id]);
-    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
-    sumwg = simd_sum(local_sumwg[simd_lane_id]);
-    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
-    if (simd_lane_id == 0) {
-      float mean = sumx / axis_size;
-      float variance = sumx2 / axis_size - mean * mean;
-
-      local_mean[0] = mean;
-      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
-      local_meanwg[0] = sumwg / axis_size;
-      local_meanwgx[0] = sumwgx / axis_size;
+  // Compute the necessary scaling factors using the mean
+  float factors[3] = {0};
+  constexpr int meanwg = 0;
+  constexpr int meanwgxc = 1;
+  constexpr int normalizer2 = 2;
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        float t = x[i + r] - mean;
+        float wi = w[(i + r) * w_stride];
+        float gi = g[i + r];
+        float wg = wi * gi;
+        factors[meanwg] += wg;
+        factors[meanwgxc] += wg * t;
+        factors[normalizer2] += t * t;
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          float t = x[i + r] - mean;
+          float wi = w[(i + r) * w_stride];
+          float gi = g[i + r];
+          float wg = wi * gi;
+          factors[meanwg] += wg;
+          factors[meanwgxc] += wg * t;
+          factors[normalizer2] += t * t;
+        }
+      }
     }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float mean = local_mean[0];
-  float normalizer = local_normalizer[0];
-  float meanwg = local_meanwg[0];
-  float meanwgxc = local_meanwgx[0] - meanwg * mean;
-  float normalizer2 = normalizer * normalizer;
+  threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
+  factors[meanwg] /= axis_size;
+  factors[meanwgxc] /= axis_size;
+  factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
+  float normalizer = metal::precise::sqrt(factors[normalizer2]);
 
   // Write the outputs
   gx += gid * size_t(axis_size) + lid * N_READS;
@@ -470,7 +396,8 @@ template <typename T, int N_READS = RMS_N_READS>
         float wi = w[(i + r) * w_stride];
         float gi = g[i + r];
         gx[i + r] = static_cast<T>(
-            normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
+            normalizer * (wi * gi - factors[meanwg]) -
+            xi * factors[meanwgxc] * factors[normalizer2]);
         if (has_w) {
           gw[i + r] = static_cast<T>(gi * xi);
         }
@@ -482,7 +409,8 @@ template <typename T, int N_READS = RMS_N_READS>
           float wi = w[(i + r) * w_stride];
           float gi = g[i + r];
           gx[i + r] = static_cast<T>(
-              normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
+              normalizer * (wi * gi - factors[meanwg]) -
+              xi * factors[meanwgxc] * factors[normalizer2]);
           if (has_w) {
             gw[i + r] = static_cast<T>(gi * xi);
           }
@@ -255,12 +255,13 @@ void LayerNorm::eval_gpu(
   auto axis_size = static_cast<uint32_t>(x.shape().back());
   int n_rows = x.data_size() / axis_size;
 
-  const int simd_size = 32;
-  const int n_reads = RMS_N_READS;
-  const int looped_limit = RMS_LOOPED_LIMIT;
+  int simd_size = 32;
+  int n_reads = 8;
+  int looped_limit = 6656;
   std::string op_name = "layer_norm";
   if (axis_size > looped_limit) {
     op_name += "_looped";
+    n_reads = 4;
   }
   op_name += type_to_name(out);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -272,7 +273,13 @@ void LayerNorm::eval_gpu(
     size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
     size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
     size_t threadgroup_size = simd_size * simds_needed;
-    assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+    if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
+      std::ostringstream msg;
+      msg << "[layer_norm] Threadgroup size " << threadgroup_size
+          << " is larger than the maximum allowed threadgroup size "
+          << kernel->maxTotalThreadsPerThreadgroup();
+      throw std::runtime_error(msg.str());
+    }
     size_t n_threads = n_rows * threadgroup_size;
     grid_dims = MTL::Size(n_threads, 1, 1);
     group_dims = MTL::Size(threadgroup_size, 1, 1);
@@ -372,12 +379,13 @@ void LayerNormVJP::eval_gpu(
         g, gb, "sum", plan, {0}, compute_encoder, d, s);
   }
 
-  const int simd_size = 32;
-  const int n_reads = RMS_N_READS;
-  const int looped_limit = RMS_LOOPED_LIMIT;
+  int simd_size = 32;
+  int n_reads = 8;
+  int looped_limit = 8192;
   std::string op_name = "vjp_layer_norm";
   if (axis_size > looped_limit) {
     op_name += "_looped";
+    n_reads = 4;
   }
   op_name += type_to_name(gx);
 
@@ -394,7 +402,13 @@ void LayerNormVJP::eval_gpu(
     size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
     size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
     size_t threadgroup_size = simd_size * simds_needed;
-    assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+    if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) {
+      std::ostringstream msg;
+      msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size
+          << " is larger than the maximum allowed threadgroup size "
+          << kernel->maxTotalThreadsPerThreadgroup();
+      throw std::runtime_error(msg.str());
+    }
     size_t n_threads = n_rows * threadgroup_size;
     grid_dims = MTL::Size(n_threads, 1, 1);
     group_dims = MTL::Size(threadgroup_size, 1, 1);
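The dispatch logic above replaces the compile-time RMS_N_READS/RMS_LOOPED_LIMIT constants with per-op values (8 reads per thread for the single-row kernels, dropping to 4 for the looped variants), and turns the threadgroup-size assert into a proper runtime error. A sketch of the arithmetic that makes 6656 a comfortable single-row limit (the 1024-thread maximum is the common Apple GPU value, assumed here):

def threadgroup_size(axis_size, n_reads=8, simd_size=32):
    threads_needed = (axis_size + n_reads - 1) // n_reads
    simds_needed = (threads_needed + simd_size - 1) // simd_size
    return simd_size * simds_needed

print(threadgroup_size(6656))  # 832 threads, well under 1024
print(threadgroup_size(9216))  # 1152 threads: too many, so the looped kernel is used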
@@ -369,7 +369,7 @@ bool ScaledDotProductAttention::use_fallback(
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);
 
-  const bool supports_sdpa_full =
+  const bool supports_sdpa_full = query_sequence_length > 8 &&
       sdpa_full_supported_mask && sdpa_full_supported_head_dim;
 
   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
@@ -231,13 +231,11 @@ array layer_norm(
       const std::vector<array>& inputs) {
     auto x = astype(inputs[0], float32, s);
 
-    // Should I not be smart here and leave the double mean to simplify()?
     auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
-    auto mu2 = square(mu, s);
-    auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
-    auto v = subtract(x2, mu2, s);
+    auto xc = subtract(x, mu, s);
+    auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s);
 
-    x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
+    x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s));
     x = astype(x, out_type, s);
 
     // If the LN is affine then transform x according to the weight and bias