|
|
|
@@ -9,7 +9,41 @@ using namespace metal;
|
|
|
|
|
|
|
|
|
|
// Function constant (index 20). The VJP kernels in this file guard their
// writes to the `gw` output with it, so `gw` is only written when it is true.
constant bool has_w [[function_constant(20)]];
|
|
|
|
|
|
|
|
|
|
template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
// Zero the threadgroup scratch buffer used by the reductions below.
//
// Only the threads of simdgroup 0 write: lane `l` clears the N consecutive
// floats starting at xs[N * l]. Every thread then waits on a threadgroup
// barrier, so the zeros are visible to the whole threadgroup on return.
//
//   xs            threadgroup scratch buffer to clear
//   simd_lane_id  index of the calling thread within its simdgroup
//   simd_group_id index of the calling thread's simdgroup in the threadgroup
//
// Contains a threadgroup barrier, so all threads of the threadgroup must
// reach this call.
template <int N = 1>
inline void initialize_buffer(
    threadgroup float* xs,
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  const bool is_first_simdgroup = (simd_group_id == 0);
  if (is_first_simdgroup) {
    // Each lane owns a contiguous run of N slots.
    threadgroup float* slot = xs + N * simd_lane_id;
    for (int k = 0; k < N; ++k) {
      slot[k] = 0;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
}
|
|
|
|
|
|
|
|
|
|
// Sum N per-thread values across the whole threadgroup, in place.
//
//   x             N per-thread values; overwritten with the reduced sums
//   xs            threadgroup scratch buffer; simdgroup `g` writes
//                 xs[N * g + i] and lane `l` reads xs[N * l + i]
//   simd_lane_id  index of the calling thread within its simdgroup
//   simd_group_id index of the calling thread's simdgroup in the threadgroup
//
// The reduction runs in two stages: a simd_sum inside each simdgroup, then a
// second simd_sum over the per-simdgroup partials staged in `xs`. Slots of
// `xs` that no simdgroup writes are read as-is, so the caller is expected to
// have zeroed the buffer beforehand (see initialize_buffer).
//
// Contains threadgroup barriers, so all threads of the threadgroup must
// reach this call.
template <int N = 1>
inline void threadgroup_sum(
    thread float* x,
    threadgroup float* xs,
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Stage 1: reduce within each simdgroup.
  for (int i = 0; i < N; i++) {
    x[i] = simd_sum(x[i]);
  }

  // Wait until every simdgroup is done reading `xs` from a previous
  // reduction before overwriting it. Without this barrier, back-to-back
  // calls sharing the same scratch buffer race: a fast simdgroup can store
  // its new partial while a slower one is still loading the old partials.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // One lane per simdgroup publishes that simdgroup's partial sums.
  if (simd_lane_id == 0) {
    for (int i = 0; i < N; i++) {
      xs[N * simd_group_id + i] = x[i];
    }
  }

  // Make the published partials visible to every thread.
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 2: each lane loads one simdgroup's partial and the simdgroup
  // reduces them, leaving the threadgroup-wide sum in x[i].
  for (int i = 0; i < N; i++) {
    x[i] = xs[N * simd_lane_id + i];
    x[i] = simd_sum(x[i]);
  }
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int N_READS = 8>
|
|
|
|
|
[[kernel]] void layer_norm_single_row(
|
|
|
|
|
const device T* x,
|
|
|
|
|
const device T* w,
|
|
|
|
@@ -23,90 +57,71 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
uint lid [[thread_position_in_threadgroup]],
|
|
|
|
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
|
|
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
float sumx = 0;
|
|
|
|
|
float sumx2 = 0;
|
|
|
|
|
float thread_x[N_READS];
|
|
|
|
|
|
|
|
|
|
constexpr int SIMD_SIZE = 32;
|
|
|
|
|
|
|
|
|
|
threadgroup float local_sumx[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumx2[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_mean[1];
|
|
|
|
|
threadgroup float local_normalizer[1];
|
|
|
|
|
// Initialize the registers and threadgroup memory
|
|
|
|
|
float thread_x[N_READS] = {0};
|
|
|
|
|
threadgroup float local_buffer[SIMD_SIZE] = {0};
|
|
|
|
|
initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
|
|
|
|
|
// Advance the pointers
|
|
|
|
|
x += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
w += w_stride * lid * N_READS;
|
|
|
|
|
b += b_stride * lid * N_READS;
|
|
|
|
|
out += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
|
|
|
|
|
if (lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
// Compute some variables for reading writing etc
|
|
|
|
|
const bool safe = lid * N_READS + N_READS <= axis_size;
|
|
|
|
|
const int n = axis_size - lid * N_READS;
|
|
|
|
|
|
|
|
|
|
// Read the inputs
|
|
|
|
|
if (safe) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] = x[i];
|
|
|
|
|
sumx2 += thread_x[i] * thread_x[i];
|
|
|
|
|
sumx += thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((lid * N_READS + i) < axis_size) {
|
|
|
|
|
thread_x[i] = x[i];
|
|
|
|
|
sumx2 += thread_x[i] * thread_x[i];
|
|
|
|
|
sumx += thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
thread_x[i] = x[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumx = simd_sum(sumx);
|
|
|
|
|
sumx2 = simd_sum(sumx2);
|
|
|
|
|
|
|
|
|
|
// Initialize shared memory
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
local_sumx[simd_lane_id] = 0;
|
|
|
|
|
local_sumx2[simd_lane_id] = 0;
|
|
|
|
|
// Compute the mean
|
|
|
|
|
float mean = 0;
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
mean += thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
mean /= axis_size;
|
|
|
|
|
|
|
|
|
|
// Write simd accumulations into shared memory
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
local_sumx[simd_group_id] = sumx;
|
|
|
|
|
local_sumx2[simd_group_id] = sumx2;
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
// Accumulate over simd groups
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
sumx = simd_sum(local_sumx[simd_lane_id]);
|
|
|
|
|
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
float mean = sumx / axis_size;
|
|
|
|
|
float variance = sumx2 / axis_size - mean * mean;
|
|
|
|
|
|
|
|
|
|
local_mean[0] = mean;
|
|
|
|
|
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
|
|
|
|
// Compute the normalizer
|
|
|
|
|
float normalizer = 0;
|
|
|
|
|
if (!safe) {
|
|
|
|
|
for (int i = n; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] = mean;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
float mean = local_mean[0];
|
|
|
|
|
float normalizer = local_normalizer[0];
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] -= mean;
|
|
|
|
|
normalizer += thread_x[i] * thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
|
|
|
|
|
|
|
|
|
|
// Write the outputs
|
|
|
|
|
out += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
if (lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
if (safe) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
|
|
|
|
thread_x[i] *= normalizer;
|
|
|
|
|
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((lid * N_READS + i) < axis_size) {
|
|
|
|
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
|
|
|
|
out[i] =
|
|
|
|
|
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
thread_x[i] *= normalizer;
|
|
|
|
|
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
template <typename T, int N_READS = 4>
|
|
|
|
|
[[kernel]] void layer_norm_looped(
|
|
|
|
|
const device T* x,
|
|
|
|
|
const device T* w,
|
|
|
|
@@ -121,71 +136,52 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
uint lsize [[threads_per_threadgroup]],
|
|
|
|
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
|
|
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
float sumx = 0;
|
|
|
|
|
float sumx2 = 0;
|
|
|
|
|
|
|
|
|
|
constexpr int SIMD_SIZE = 32;
|
|
|
|
|
|
|
|
|
|
threadgroup float local_sumx[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumx2[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_mean[1];
|
|
|
|
|
threadgroup float local_normalizer[1];
|
|
|
|
|
threadgroup float local_buffer[SIMD_SIZE];
|
|
|
|
|
initialize_buffer(local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
|
|
|
|
|
x += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
w += w_stride * lid * N_READS;
|
|
|
|
|
b += b_stride * lid * N_READS;
|
|
|
|
|
|
|
|
|
|
// Compute the mean
|
|
|
|
|
float mean = 0;
|
|
|
|
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
|
|
|
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
float xi = x[i + r];
|
|
|
|
|
sumx2 += xi * xi;
|
|
|
|
|
sumx += xi;
|
|
|
|
|
mean += x[i + r];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((r + lid * N_READS + i) < axis_size) {
|
|
|
|
|
float xi = x[i + r];
|
|
|
|
|
sumx2 += xi * xi;
|
|
|
|
|
sumx += xi;
|
|
|
|
|
mean += x[i + r];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
mean /= axis_size;
|
|
|
|
|
|
|
|
|
|
sumx = simd_sum(sumx);
|
|
|
|
|
sumx2 = simd_sum(sumx2);
|
|
|
|
|
|
|
|
|
|
// Initialize shared memory
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
local_sumx[simd_lane_id] = 0;
|
|
|
|
|
local_sumx2[simd_lane_id] = 0;
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
// Write simd accumulations into shared memory
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
local_sumx[simd_group_id] = sumx;
|
|
|
|
|
local_sumx2[simd_group_id] = sumx2;
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
// Accumulate over simd groups
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
sumx = simd_sum(local_sumx[simd_lane_id]);
|
|
|
|
|
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
float mean = sumx / axis_size;
|
|
|
|
|
float variance = sumx2 / axis_size - mean * mean;
|
|
|
|
|
|
|
|
|
|
local_mean[0] = mean;
|
|
|
|
|
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
|
|
|
|
// Compute the normalizer
|
|
|
|
|
float normalizer = 0;
|
|
|
|
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
|
|
|
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
float t = x[i + r] - mean;
|
|
|
|
|
normalizer += t * t;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((r + lid * N_READS + i) < axis_size) {
|
|
|
|
|
float t = x[i + r] - mean;
|
|
|
|
|
normalizer += t * t;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
float mean = local_mean[0];
|
|
|
|
|
float normalizer = local_normalizer[0];
|
|
|
|
|
threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
normalizer = metal::precise::rsqrt(normalizer / axis_size + eps);
|
|
|
|
|
|
|
|
|
|
// Write the outputs
|
|
|
|
|
out += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
@@ -208,7 +204,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
template <typename T, int N_READS = 8>
|
|
|
|
|
[[kernel]] void vjp_layer_norm_single_row(
|
|
|
|
|
const device T* x,
|
|
|
|
|
const device T* w,
|
|
|
|
@@ -222,133 +218,96 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
uint lid [[thread_position_in_threadgroup]],
|
|
|
|
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
|
|
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
constexpr int SIMD_SIZE = 32;
|
|
|
|
|
|
|
|
|
|
// Advance the input pointers
|
|
|
|
|
x += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
g += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
w += w_stride * lid * N_READS;
|
|
|
|
|
|
|
|
|
|
// Allocate registers for the computation and accumulators
|
|
|
|
|
float thread_x[N_READS];
|
|
|
|
|
float thread_w[N_READS];
|
|
|
|
|
float thread_g[N_READS];
|
|
|
|
|
float sumx = 0;
|
|
|
|
|
float sumx2 = 0;
|
|
|
|
|
float sumwg = 0;
|
|
|
|
|
float sumwgx = 0;
|
|
|
|
|
// Initialize the registers and threadgroup memory
|
|
|
|
|
float thread_x[N_READS] = {0};
|
|
|
|
|
float thread_w[N_READS] = {0};
|
|
|
|
|
float thread_g[N_READS] = {0};
|
|
|
|
|
threadgroup float local_buffer[3 * SIMD_SIZE];
|
|
|
|
|
initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
|
|
|
|
|
constexpr int SIMD_SIZE = 32;
|
|
|
|
|
// Compute some variables for reading writing etc
|
|
|
|
|
const bool safe = lid * N_READS + N_READS <= axis_size;
|
|
|
|
|
const int n = axis_size - lid * N_READS;
|
|
|
|
|
|
|
|
|
|
threadgroup float local_sumx[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumx2[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumwg[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumwgx[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_mean[1];
|
|
|
|
|
threadgroup float local_normalizer[1];
|
|
|
|
|
threadgroup float local_meanwg[1];
|
|
|
|
|
threadgroup float local_meanwgx[1];
|
|
|
|
|
|
|
|
|
|
if (lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
// Read the inputs
|
|
|
|
|
if (safe) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] = x[i];
|
|
|
|
|
thread_w[i] = w[i * w_stride];
|
|
|
|
|
thread_g[i] = g[i];
|
|
|
|
|
float wg = thread_w[i] * thread_g[i];
|
|
|
|
|
sumx += thread_x[i];
|
|
|
|
|
sumx2 += thread_x[i] * thread_x[i];
|
|
|
|
|
sumwg += wg;
|
|
|
|
|
sumwgx += wg * thread_x[i];
|
|
|
|
|
thread_w[i] = w[i * w_stride];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((lid * N_READS + i) < axis_size) {
|
|
|
|
|
thread_x[i] = x[i];
|
|
|
|
|
thread_w[i] = w[i * w_stride];
|
|
|
|
|
thread_g[i] = g[i];
|
|
|
|
|
float wg = thread_w[i] * thread_g[i];
|
|
|
|
|
sumx += thread_x[i];
|
|
|
|
|
sumx2 += thread_x[i] * thread_x[i];
|
|
|
|
|
sumwg += wg;
|
|
|
|
|
sumwgx += wg * thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
thread_x[i] = x[i];
|
|
|
|
|
thread_g[i] = g[i];
|
|
|
|
|
thread_w[i] = w[i * w_stride];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumx = simd_sum(sumx);
|
|
|
|
|
sumx2 = simd_sum(sumx2);
|
|
|
|
|
sumwg = simd_sum(sumwg);
|
|
|
|
|
sumwgx = simd_sum(sumwgx);
|
|
|
|
|
|
|
|
|
|
// Initialize shared memory
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
local_sumx[simd_lane_id] = 0;
|
|
|
|
|
local_sumx2[simd_lane_id] = 0;
|
|
|
|
|
local_sumwg[simd_lane_id] = 0;
|
|
|
|
|
local_sumwgx[simd_lane_id] = 0;
|
|
|
|
|
// Compute the mean
|
|
|
|
|
float mean = 0;
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
mean += thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
mean /= axis_size;
|
|
|
|
|
|
|
|
|
|
// Write simd accumulations into shared memory
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
local_sumx[simd_group_id] = sumx;
|
|
|
|
|
local_sumx2[simd_group_id] = sumx2;
|
|
|
|
|
local_sumwg[simd_group_id] = sumwg;
|
|
|
|
|
local_sumwgx[simd_group_id] = sumwgx;
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
// Accumulate over simd groups
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
sumx = simd_sum(local_sumx[simd_lane_id]);
|
|
|
|
|
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
|
|
|
|
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
|
|
|
|
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
float mean = sumx / axis_size;
|
|
|
|
|
float variance = sumx2 / axis_size - mean * mean;
|
|
|
|
|
|
|
|
|
|
local_mean[0] = mean;
|
|
|
|
|
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
|
|
|
|
local_meanwg[0] = sumwg / axis_size;
|
|
|
|
|
local_meanwgx[0] = sumwgx / axis_size;
|
|
|
|
|
// Compute the necessary scaling factors using the mean
|
|
|
|
|
if (!safe) {
|
|
|
|
|
for (int i = n; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] = mean;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
float mean = local_mean[0];
|
|
|
|
|
float normalizer = local_normalizer[0];
|
|
|
|
|
float meanwg = local_meanwg[0];
|
|
|
|
|
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
|
|
|
|
float normalizer2 = normalizer * normalizer;
|
|
|
|
|
float factors[3] = {0};
|
|
|
|
|
constexpr int meanwg = 0;
|
|
|
|
|
constexpr int meanwgxc = 1;
|
|
|
|
|
constexpr int normalizer2 = 2;
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] -= mean;
|
|
|
|
|
factors[meanwg] += thread_w[i] * thread_g[i];
|
|
|
|
|
factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i];
|
|
|
|
|
factors[normalizer2] += thread_x[i] * thread_x[i];
|
|
|
|
|
}
|
|
|
|
|
threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
factors[meanwg] /= axis_size;
|
|
|
|
|
factors[meanwgxc] /= axis_size;
|
|
|
|
|
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
|
|
|
|
|
float normalizer = metal::precise::sqrt(factors[normalizer2]);
|
|
|
|
|
|
|
|
|
|
// Write the outputs
|
|
|
|
|
gx += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
gw += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
if (lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
if (safe) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
|
|
|
|
thread_x[i] *= normalizer;
|
|
|
|
|
gx[i] = static_cast<T>(
|
|
|
|
|
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
|
|
|
|
thread_x[i] * meanwgxc * normalizer2);
|
|
|
|
|
normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
|
|
|
|
|
thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
|
|
|
|
|
if (has_w) {
|
|
|
|
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((lid * N_READS + i) < axis_size) {
|
|
|
|
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
|
|
|
|
gx[i] = static_cast<T>(
|
|
|
|
|
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
|
|
|
|
thread_x[i] * meanwgxc * normalizer2);
|
|
|
|
|
if (has_w) {
|
|
|
|
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
thread_x[i] *= normalizer;
|
|
|
|
|
gx[i] = static_cast<T>(
|
|
|
|
|
normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) -
|
|
|
|
|
thread_x[i] * factors[meanwgxc] * factors[normalizer2]);
|
|
|
|
|
if (has_w) {
|
|
|
|
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
template <typename T, int N_READS = 4>
|
|
|
|
|
[[kernel]] void vjp_layer_norm_looped(
|
|
|
|
|
const device T* x,
|
|
|
|
|
const device T* w,
|
|
|
|
@@ -363,102 +322,69 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
uint lsize [[threads_per_threadgroup]],
|
|
|
|
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
|
|
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
constexpr int SIMD_SIZE = 32;
|
|
|
|
|
|
|
|
|
|
// Advance the input pointers
|
|
|
|
|
x += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
g += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
|
w += w_stride * lid * N_READS;
|
|
|
|
|
|
|
|
|
|
// Allocate registers for the accumulators
|
|
|
|
|
float sumx = 0;
|
|
|
|
|
float sumx2 = 0;
|
|
|
|
|
float sumwg = 0;
|
|
|
|
|
float sumwgx = 0;
|
|
|
|
|
|
|
|
|
|
constexpr int SIMD_SIZE = 32;
|
|
|
|
|
|
|
|
|
|
threadgroup float local_sumx[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumx2[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumwg[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_sumwgx[SIMD_SIZE];
|
|
|
|
|
threadgroup float local_mean[1];
|
|
|
|
|
threadgroup float local_normalizer[1];
|
|
|
|
|
threadgroup float local_meanwg[1];
|
|
|
|
|
threadgroup float local_meanwgx[1];
|
|
|
|
|
threadgroup float local_buffer[3 * SIMD_SIZE];
|
|
|
|
|
initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
|
|
|
|
|
// Compute the mean
|
|
|
|
|
float mean = 0;
|
|
|
|
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
|
|
|
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
float xi = x[i + r];
|
|
|
|
|
float wi = w[(i + r) * w_stride];
|
|
|
|
|
float gi = g[i + r];
|
|
|
|
|
float wg = wi * gi;
|
|
|
|
|
sumx += xi;
|
|
|
|
|
sumx2 += xi * xi;
|
|
|
|
|
sumwg += wg;
|
|
|
|
|
sumwgx += wg * xi;
|
|
|
|
|
mean += x[i + r];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((r + lid * N_READS + i) < axis_size) {
|
|
|
|
|
float xi = x[i + r];
|
|
|
|
|
float wi = w[(i + r) * w_stride];
|
|
|
|
|
float gi = g[i + r];
|
|
|
|
|
float wg = wi * gi;
|
|
|
|
|
sumx += xi;
|
|
|
|
|
sumx2 += xi * xi;
|
|
|
|
|
sumwg += wg;
|
|
|
|
|
sumwgx += wg * xi;
|
|
|
|
|
mean += x[i + r];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
mean /= axis_size;
|
|
|
|
|
|
|
|
|
|
sumx = simd_sum(sumx);
|
|
|
|
|
sumx2 = simd_sum(sumx2);
|
|
|
|
|
sumwg = simd_sum(sumwg);
|
|
|
|
|
sumwgx = simd_sum(sumwgx);
|
|
|
|
|
|
|
|
|
|
// Initialize shared memory
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
local_sumx[simd_lane_id] = 0;
|
|
|
|
|
local_sumx2[simd_lane_id] = 0;
|
|
|
|
|
local_sumwg[simd_lane_id] = 0;
|
|
|
|
|
local_sumwgx[simd_lane_id] = 0;
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
// Write simd accumulations into shared memory
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
local_sumx[simd_group_id] = sumx;
|
|
|
|
|
local_sumx2[simd_group_id] = sumx2;
|
|
|
|
|
local_sumwg[simd_group_id] = sumwg;
|
|
|
|
|
local_sumwgx[simd_group_id] = sumwgx;
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
// Accumulate over simd groups
|
|
|
|
|
if (simd_group_id == 0) {
|
|
|
|
|
sumx = simd_sum(local_sumx[simd_lane_id]);
|
|
|
|
|
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
|
|
|
|
sumwg = simd_sum(local_sumwg[simd_lane_id]);
|
|
|
|
|
sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
|
|
|
|
|
if (simd_lane_id == 0) {
|
|
|
|
|
float mean = sumx / axis_size;
|
|
|
|
|
float variance = sumx2 / axis_size - mean * mean;
|
|
|
|
|
|
|
|
|
|
local_mean[0] = mean;
|
|
|
|
|
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
|
|
|
|
local_meanwg[0] = sumwg / axis_size;
|
|
|
|
|
local_meanwgx[0] = sumwgx / axis_size;
|
|
|
|
|
// Compute the necessary scaling factors using the mean
|
|
|
|
|
float factors[3] = {0};
|
|
|
|
|
constexpr int meanwg = 0;
|
|
|
|
|
constexpr int meanwgxc = 1;
|
|
|
|
|
constexpr int normalizer2 = 2;
|
|
|
|
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
|
|
|
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
float t = x[i + r] - mean;
|
|
|
|
|
float wi = w[(i + r) * w_stride];
|
|
|
|
|
float gi = g[i + r];
|
|
|
|
|
float wg = wi * gi;
|
|
|
|
|
factors[meanwg] += wg;
|
|
|
|
|
factors[meanwgxc] += wg * t;
|
|
|
|
|
factors[normalizer2] += t * t;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < N_READS; i++) {
|
|
|
|
|
if ((r + lid * N_READS + i) < axis_size) {
|
|
|
|
|
float t = x[i + r] - mean;
|
|
|
|
|
float wi = w[(i + r) * w_stride];
|
|
|
|
|
float gi = g[i + r];
|
|
|
|
|
float wg = wi * gi;
|
|
|
|
|
factors[meanwg] += wg;
|
|
|
|
|
factors[meanwgxc] += wg * t;
|
|
|
|
|
factors[normalizer2] += t * t;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
|
|
float mean = local_mean[0];
|
|
|
|
|
float normalizer = local_normalizer[0];
|
|
|
|
|
float meanwg = local_meanwg[0];
|
|
|
|
|
float meanwgxc = local_meanwgx[0] - meanwg * mean;
|
|
|
|
|
float normalizer2 = normalizer * normalizer;
|
|
|
|
|
threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id);
|
|
|
|
|
factors[meanwg] /= axis_size;
|
|
|
|
|
factors[meanwgxc] /= axis_size;
|
|
|
|
|
factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps);
|
|
|
|
|
float normalizer = metal::precise::sqrt(factors[normalizer2]);
|
|
|
|
|
|
|
|
|
|
// Write the outputs
|
|
|
|
|
gx += gid * size_t(axis_size) + lid * N_READS;
|
|
|
|
@@ -470,7 +396,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
float wi = w[(i + r) * w_stride];
|
|
|
|
|
float gi = g[i + r];
|
|
|
|
|
gx[i + r] = static_cast<T>(
|
|
|
|
|
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
|
|
|
|
normalizer * (wi * gi - factors[meanwg]) -
|
|
|
|
|
xi * factors[meanwgxc] * factors[normalizer2]);
|
|
|
|
|
if (has_w) {
|
|
|
|
|
gw[i + r] = static_cast<T>(gi * xi);
|
|
|
|
|
}
|
|
|
|
@@ -482,7 +409,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|
|
|
|
float wi = w[(i + r) * w_stride];
|
|
|
|
|
float gi = g[i + r];
|
|
|
|
|
gx[i + r] = static_cast<T>(
|
|
|
|
|
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
|
|
|
|
normalizer * (wi * gi - factors[meanwg]) -
|
|
|
|
|
xi * factors[meanwgxc] * factors[normalizer2]);
|
|
|
|
|
if (has_w) {
|
|
|
|
|
gw[i + r] = static_cast<T>(gi * xi);
|
|
|
|
|
}
|
|
|
|
|