RMS norm without scaling (#1915)

Author: Angelos Katharopoulos (committed by GitHub)
Date: 2025-02-28 20:26:57 -08:00
parent 5d68082881
commit 5e6c130d93
9 changed files with 220 additions and 101 deletions
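
This change makes the weight argument of fast::rms_norm optional (std::optional<array>), so the normalization can run without the trailing scaling step; layer_norm gets the same treatment. A minimal usage sketch of the new signature, assuming the mlx::core headers and fast namespace from this repository (shapes, eps, and values are illustrative):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto x = random::normal({8, 64});
      // As before: normalize, then scale by a learned weight.
      auto y_scaled = fast::rms_norm(x, ones({64}), 1e-5f);
      // New in this commit: pass std::nullopt to skip scaling entirely.
      auto y_plain = fast::rms_norm(x, std::nullopt, 1e-5f);
      eval({y_scaled, y_plain});
      return 0;
    }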


@@ -54,30 +54,34 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
 array rms_norm(
     const array& x,
-    const array& weight,
+    const std::optional<array>& weight,
     float eps,
     StreamOrDevice s_ /* = {} */) {
+  bool has_weight = weight.has_value();
   if (x.ndim() == 0) {
     std::ostringstream msg;
     msg << "[rms_norm] Input must have at least 1 dimension but got input with "
            "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (weight.ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
-        << " dimensions.";
-    throw std::invalid_argument(msg.str());
-  }
-  if (weight.size() != x.shape(-1)) {
-    std::ostringstream msg;
-    msg << "[rms_norm] weight must have the same size as the last dimension of"
-           " x but has "
-        << weight.size() << " elements.";
-    throw std::invalid_argument(msg.str());
+  if (has_weight) {
+    if ((*weight).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[rms_norm] weight must have 1 dimension but has "
+          << (*weight).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*weight).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[rms_norm] weight must have the same size as the last dimension of"
+             " x but has "
+          << (*weight).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }
-  auto out_type = result_type(x, weight);
+  auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype();
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[rms_norm] Received unsupported type " << out_type << ".";
@@ -85,27 +89,36 @@ array rms_norm(
   }
   auto s = to_stream(s_);
-  auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
-    auto x = astype(inputs[0], float32, s);
-    x = multiply(
-        x,
-        rsqrt(
-            add(mean(square(x, s), -1, /* keepdims */ true, s),
-                array(eps, float32),
-                s),
-            s),
-        s);
-    x = astype(x, out_type, s);
-    return std::vector<array>{multiply(inputs[1], x, s)};
-  };
+  auto fallback =
+      [has_weight, eps, out_type, s](const std::vector<array>& inputs) {
+        auto x = astype(inputs[0], float32, s);
+        x = multiply(
+            x,
+            rsqrt(
+                add(mean(square(x, s), -1, /* keepdims */ true, s),
+                    array(eps, float32),
+                    s),
+                s),
+            s);
+        x = astype(x, out_type, s);
+        if (has_weight) {
+          x = multiply(x, inputs[1], s);
+        }
+        return std::vector<array>{x};
+      };
+  auto passed_weight =
+      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
   if (s.device == Device::gpu) {
     return array(
         x.shape(),
         out_type,
         std::make_shared<RMSNorm>(s, fallback, eps),
-        {astype(x, out_type, s), astype(weight, out_type, s)});
+        {astype(x, out_type, s), passed_weight});
   }
-  return fallback({x, weight})[0];
+  return fallback({x, passed_weight})[0];
 }
 std::vector<array> RMSNorm::vjp(
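
For reference, the fallback above computes, in float32,

    y = x * rsqrt(mean(x * x, axis = -1, keepdims = true) + eps)

then casts back to out_type and multiplies by the weight only when one was given. On the GPU path the RMSNorm primitive still receives two inputs: the scalar array(1, out_type) (passed_weight) stands in for the absent weight, keeping the primitive's input arity unchanged.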
@@ -141,8 +154,12 @@ std::vector<array> RMSNorm::vjp(
   // df/dw
   std::vector<int> axes(g.ndim() - 1);
   std::iota(axes.begin(), axes.end(), 0);
-  vjps.push_back(
-      sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
+  if (w.ndim() == 0) {
+    vjps.push_back(zeros_like(w, s));
+  } else {
+    vjps.push_back(sum(
+        multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
+  }
   return vjps;
 };
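
This branch is the flip side of the scalar placeholder: when rms_norm was called without a weight, w is the 0-dimensional array(1, out_type), so the weight cotangent is a matching scalar zero rather than the usual reduction of g * (x * n) over the leading axes.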
@@ -177,28 +194,30 @@ array layer_norm(
     const std::optional<array>& bias,
     float eps,
     StreamOrDevice s_ /* = {} */) {
+  bool has_weight = weight.has_value();
+  bool has_bias = bias.has_value();
   if (x.ndim() == 0) {
     std::ostringstream msg;
     msg << "[layer_norm] Input must have at least 1 dimension but got input with "
            "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (weight.has_value() && (*weight).ndim() != 1) {
+  if (has_weight && (*weight).ndim() != 1) {
     std::ostringstream msg;
     msg << "[layer_norm] weight must have 1 dimension but has "
         << (*weight).ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (bias.has_value() && (*bias).ndim() != 1) {
+  if (has_bias && (*bias).ndim() != 1) {
     std::ostringstream msg;
     msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
         << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  auto out_type = (weight.has_value())
-      ? ((bias.has_value()) ? result_type(x, *weight, *bias)
-                            : result_type(x, *weight))
+  auto out_type = (has_weight)
+      ? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight))
       : x.dtype();
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
@@ -207,8 +226,6 @@ array layer_norm(
   }
   auto s = to_stream(s_);
-  bool has_weight = weight.has_value();
-  bool has_bias = bias.has_value();
   auto fallback = [has_weight, has_bias, eps, out_type, s](
                       const std::vector<array>& inputs) {
     auto x = astype(inputs[0], float32, s);
@@ -234,9 +251,9 @@ array layer_norm(
   };
   auto passed_weight =
-      astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
+      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
   auto passed_bias =
-      astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
+      (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
   if (s.device == Device::gpu) {
     return array(
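
Combined with the passed_weight/passed_bias defaults above, omitting either parameter amounts to scaling by 1 and shifting by 0 after normalization. A hedged sketch of the resulting call sites, again assuming mlx::core and an illustrative feature size of 64:

    // No learned affine transform at all:
    auto y0 = fast::layer_norm(x, std::nullopt, std::nullopt, 1e-5f);
    // Scale without shift: only the bias stays std::nullopt.
    auto y1 = fast::layer_norm(x, ones({64}), std::nullopt, 1e-5f);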