mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
RMS norm without scaling (#1915)
This commit is contained in:
committed by
GitHub
parent
5d68082881
commit
5e6c130d93
@@ -7,6 +7,8 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
constant bool has_w [[function_constant(20)]];
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void layer_norm_single_row(
|
||||
const device T* x,
|
||||
@@ -327,7 +329,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
gx[i] = static_cast<T>(
|
||||
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
if (has_w) {
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -336,7 +340,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
gx[i] = static_cast<T>(
|
||||
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||
thread_x[i] * meanwgxc * normalizer2);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
if (has_w) {
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -465,7 +471,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(
|
||||
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
if (has_w) {
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -475,7 +483,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
float gi = g[i + r];
|
||||
gx[i + r] = static_cast<T>(
|
||||
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
if (has_w) {
|
||||
gw[i + r] = static_cast<T>(gi * xi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
constant bool has_w [[function_constant(20)]];
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void rms_single_row(
|
||||
const device T* x,
|
||||
@@ -243,7 +245,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
gx[i] = static_cast<T>(
|
||||
thread_g[i] * thread_w[i] * normalizer -
|
||||
thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
if (has_w) {
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -251,7 +255,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
gx[i] = static_cast<T>(
|
||||
thread_g[i] * thread_w[i] * normalizer -
|
||||
thread_x[i] * meangwx * normalizer3);
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
if (has_w) {
|
||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -351,7 +357,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
|
||||
gx[i + r] =
|
||||
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
if (has_w) {
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
@@ -362,7 +370,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
|
||||
gx[i + r] =
|
||||
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
if (has_w) {
|
||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user