RMS norm without scaling (#1915)

This commit is contained in:
Angelos Katharopoulos
2025-02-28 20:26:57 -08:00
committed by GitHub
parent 5d68082881
commit 5e6c130d93
9 changed files with 220 additions and 101 deletions

View File

@@ -7,6 +7,8 @@
using namespace metal;
// Boolean function constant bound to function-constant index 20.
// In the gradient loops visible in this file, the stores to `gw` are
// performed only when `has_w` is true.
// NOTE(review): presumably true when a weight (scale) input is supplied to
// the kernel — confirm against the host code that sets constant 20.
constant bool has_w [[function_constant(20)]];
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_single_row(
const device T* x,
@@ -327,7 +329,9 @@ template <typename T, int N_READS = RMS_N_READS>
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
} else {
for (int i = 0; i < N_READS; i++) {
@@ -336,7 +340,9 @@ template <typename T, int N_READS = RMS_N_READS>
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
}
}
@@ -465,7 +471,9 @@ template <typename T, int N_READS = RMS_N_READS>
float gi = g[i + r];
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi);
}
}
} else {
for (int i = 0; i < N_READS; i++) {
@@ -475,7 +483,9 @@ template <typename T, int N_READS = RMS_N_READS>
float gi = g[i + r];
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi);
}
}
}
}

View File

@@ -7,6 +7,8 @@
using namespace metal;
// Boolean function constant bound to function-constant index 20.
// In the gradient loops visible in this file, the stores to `gw` are
// performed only when `has_w` is true.
// NOTE(review): presumably true when a weight (scale) input is supplied to
// the kernel — confirm against the host code that sets constant 20.
constant bool has_w [[function_constant(20)]];
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_single_row(
const device T* x,
@@ -243,7 +245,9 @@ template <typename T, int N_READS = RMS_N_READS>
gx[i] = static_cast<T>(
thread_g[i] * thread_w[i] * normalizer -
thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
}
} else {
for (int i = 0; i < N_READS; i++) {
@@ -251,7 +255,9 @@ template <typename T, int N_READS = RMS_N_READS>
gx[i] = static_cast<T>(
thread_g[i] * thread_w[i] * normalizer -
thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
if (has_w) {
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
}
}
}
@@ -351,7 +357,9 @@ template <typename T, int N_READS = RMS_N_READS>
gx[i + r] =
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
}
} else {
for (int i = 0; i < N_READS; i++) {
@@ -362,7 +370,9 @@ template <typename T, int N_READS = RMS_N_READS>
gx[i + r] =
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
if (has_w) {
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
}
}
}