mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
RMS norm without scaling (#1915)
This commit is contained in:
committed by
GitHub
parent
5d68082881
commit
5e6c130d93
95
mlx/fast.cpp
95
mlx/fast.cpp
@@ -54,30 +54,34 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
|
||||
|
||||
array rms_norm(
|
||||
const array& x,
|
||||
const array& weight,
|
||||
const std::optional<array>& weight,
|
||||
float eps,
|
||||
StreamOrDevice s_ /* = {} */) {
|
||||
bool has_weight = weight.has_value();
|
||||
|
||||
if (x.ndim() == 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] Input must have at least 1 dimension but got input with "
|
||||
"0 dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (weight.ndim() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (weight.size() != x.shape(-1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] weight must have the same size as the last dimension of"
|
||||
" x but has "
|
||||
<< weight.size() << " elements.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
if (has_weight) {
|
||||
if ((*weight).ndim() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] (*weight) must have 1 dimension but has "
|
||||
<< (*weight).ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if ((*weight).size() != x.shape(-1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] (*weight) must have the same size as the last dimension of"
|
||||
" x but has "
|
||||
<< (*weight).size() << " elements.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
auto out_type = result_type(x, weight);
|
||||
auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype();
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
||||
@@ -85,27 +89,36 @@ array rms_norm(
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
|
||||
auto x = astype(inputs[0], float32, s);
|
||||
x = multiply(
|
||||
x,
|
||||
rsqrt(
|
||||
add(mean(square(x, s), -1, /* keepdims */ true, s),
|
||||
array(eps, float32),
|
||||
auto fallback =
|
||||
[has_weight, eps, out_type, s](const std::vector<array>& inputs) {
|
||||
auto x = astype(inputs[0], float32, s);
|
||||
x = multiply(
|
||||
x,
|
||||
rsqrt(
|
||||
add(mean(square(x, s), -1, /* keepdims */ true, s),
|
||||
array(eps, float32),
|
||||
s),
|
||||
s),
|
||||
s),
|
||||
s);
|
||||
x = astype(x, out_type, s);
|
||||
return std::vector<array>{multiply(inputs[1], x, s)};
|
||||
};
|
||||
s);
|
||||
x = astype(x, out_type, s);
|
||||
|
||||
if (has_weight) {
|
||||
x = multiply(x, inputs[1], s);
|
||||
}
|
||||
|
||||
return std::vector<array>{x};
|
||||
};
|
||||
|
||||
auto passed_weight =
|
||||
(has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
|
||||
if (s.device == Device::gpu) {
|
||||
return array(
|
||||
x.shape(),
|
||||
out_type,
|
||||
std::make_shared<RMSNorm>(s, fallback, eps),
|
||||
{astype(x, out_type, s), astype(weight, out_type, s)});
|
||||
{astype(x, out_type, s), passed_weight});
|
||||
}
|
||||
return fallback({x, weight})[0];
|
||||
return fallback({x, passed_weight})[0];
|
||||
}
|
||||
|
||||
std::vector<array> RMSNorm::vjp(
|
||||
@@ -141,8 +154,12 @@ std::vector<array> RMSNorm::vjp(
|
||||
// df/dw
|
||||
std::vector<int> axes(g.ndim() - 1);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
vjps.push_back(
|
||||
sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
|
||||
if (w.ndim() == 0) {
|
||||
vjps.push_back(zeros_like(w, s));
|
||||
} else {
|
||||
vjps.push_back(sum(
|
||||
multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
|
||||
}
|
||||
|
||||
return vjps;
|
||||
};
|
||||
@@ -177,28 +194,30 @@ array layer_norm(
|
||||
const std::optional<array>& bias,
|
||||
float eps,
|
||||
StreamOrDevice s_ /* = {} */) {
|
||||
bool has_weight = weight.has_value();
|
||||
bool has_bias = bias.has_value();
|
||||
|
||||
if (x.ndim() == 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] Input must have at least 1 dimension but got input with "
|
||||
"0 dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (weight.has_value() && (*weight).ndim() != 1) {
|
||||
if (has_weight && (*weight).ndim() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] weight must have 1 dimension but has "
|
||||
<< (*weight).ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bias.has_value() && (*bias).ndim() != 1) {
|
||||
if (has_bias && (*bias).ndim() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = (weight.has_value())
|
||||
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
||||
: result_type(x, *weight))
|
||||
auto out_type = (has_weight)
|
||||
? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight))
|
||||
: x.dtype();
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
@@ -207,8 +226,6 @@ array layer_norm(
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
bool has_weight = weight.has_value();
|
||||
bool has_bias = bias.has_value();
|
||||
auto fallback = [has_weight, has_bias, eps, out_type, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto x = astype(inputs[0], float32, s);
|
||||
@@ -234,9 +251,9 @@ array layer_norm(
|
||||
};
|
||||
|
||||
auto passed_weight =
|
||||
astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
|
||||
(has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
|
||||
auto passed_bias =
|
||||
astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
|
||||
(has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
return array(
|
||||
|
||||
Reference in New Issue
Block a user