RMS norm without scaling (#1915)

Angelos Katharopoulos
2025-02-28 20:26:57 -08:00
committed by GitHub
parent 5d68082881
commit 5e6c130d93
9 changed files with 220 additions and 101 deletions

View File

@@ -7,6 +7,8 @@
 using namespace metal;
 
+constant bool has_w [[function_constant(20)]];
+
 template <typename T, int N_READS = RMS_N_READS>
 [[kernel]] void layer_norm_single_row(
     const device T* x,
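The `has_w` function constant is the key mechanism in this change: the Metal compiler folds it at pipeline-specialization time, so the `has_w == false` variant of each kernel below contains no `gw` stores at all. As a rough sketch of the underlying API (minimal metal-cpp, not MLX's actual `get_kernel` cache; the kernel name here is hypothetical):

#include <Metal/Metal.hpp>

MTL::ComputePipelineState* specialize(
    MTL::Device* device,
    MTL::Library* library,
    bool has_w) {
  NS::Error* error = nullptr;
  // Bind the bool to function-constant index 20, matching
  // `constant bool has_w [[function_constant(20)]];` in the kernel source.
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  consts->setConstantValue(&has_w, MTL::DataTypeBool, NS::UInteger(20));
  auto* name =
      NS::String::string("vjp_layer_norm_float32", NS::UTF8StringEncoding);
  MTL::Function* fn = library->newFunction(name, consts, &error);
  consts->release();
  // The pipeline is compiled with the value of has_w baked in as a constant.
  MTL::ComputePipelineState* pso = device->newComputePipelineState(fn, &error);
  fn->release();
  return pso;
}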
@@ -327,7 +329,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           normalizer * (thread_w[i] * thread_g[i] - meanwg) -
           thread_x[i] * meanwgxc * normalizer2);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      }
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
@@ -336,7 +340,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           normalizer * (thread_w[i] * thread_g[i] - meanwg) -
           thread_x[i] * meanwgxc * normalizer2);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      }
     }
   }
 }
@@ -465,7 +471,9 @@ template <typename T, int N_READS = RMS_N_READS>
         float gi = g[i + r];
         gx[i + r] = static_cast<T>(
             normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
-        gw[i + r] = static_cast<T>(gi * xi);
+        if (has_w) {
+          gw[i + r] = static_cast<T>(gi * xi);
+        }
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
@@ -475,7 +483,9 @@ template <typename T, int N_READS = RMS_N_READS>
         float gi = g[i + r];
         gx[i + r] = static_cast<T>(
             normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
-        gw[i + r] = static_cast<T>(gi * xi);
+        if (has_w) {
+          gw[i + r] = static_cast<T>(gi * xi);
+        }
       }
     }
   }
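For reference, these guarded writes implement the standard layer-norm VJP. Per row, with mean \mu, standard deviation \sigma (including eps), and \hat{x} = (x - \mu)/\sigma, the cotangents are

\[
\frac{\partial L}{\partial x}
  = \frac{1}{\sigma}\Big(w \odot g - \overline{w \odot g}
    - \hat{x} \odot \overline{w \odot g \odot \hat{x}}\Big),
\qquad
\frac{\partial L}{\partial w} = \sum_{\text{rows}} g \odot \hat{x},
\qquad
\frac{\partial L}{\partial b} = \sum_{\text{rows}} g,
\]

where overlines denote per-row means; matching `normalizer`, `meanwg`, and `meanwgxc` to these terms follows the kernel's variable naming. The `gx` expression is needed regardless of whether a weight exists, while \partial L/\partial w has no consumer in the weightless case, which is exactly what the `has_w` guard encodes.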

View File

@@ -7,6 +7,8 @@
 using namespace metal;
 
+constant bool has_w [[function_constant(20)]];
+
 template <typename T, int N_READS = RMS_N_READS>
 [[kernel]] void rms_single_row(
     const device T* x,
@@ -243,7 +245,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
          thread_g[i] * thread_w[i] * normalizer -
           thread_x[i] * meangwx * normalizer3);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      }
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
@@ -251,7 +255,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           thread_g[i] * thread_w[i] * normalizer -
           thread_x[i] * meangwx * normalizer3);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      }
     }
   }
 }
@@ -351,7 +357,9 @@ template <typename T, int N_READS = RMS_N_READS>
         gx[i + r] =
             static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
-        gw[i + r] = static_cast<T>(gi * xi * normalizer);
+        if (has_w) {
+          gw[i + r] = static_cast<T>(gi * xi * normalizer);
+        }
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
@@ -362,7 +370,9 @@ template <typename T, int N_READS = RMS_N_READS>
         gx[i + r] =
             static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
-        gw[i + r] = static_cast<T>(gi * xi * normalizer);
+        if (has_w) {
+          gw[i + r] = static_cast<T>(gi * xi * normalizer);
+        }
       }
     }
   }
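The RMS-norm counterpart reads directly off the kernel expressions. With n = (\overline{x^2} + \varepsilon)^{-1/2} (the kernel's `normalizer`, so `normalizer3` is n^3) and forward output y = w \odot x \, n:

\[
\frac{\partial L}{\partial x} = n\,(g \odot w) - n^{3}\,\overline{g \odot w \odot x}\;x,
\qquad
\frac{\partial L}{\partial w} = \sum_{\text{rows}} n\,(g \odot x),
\]

where the overline is the per-row mean (`meangwx`). When no weight was passed, the second expression is never consumed, so the specialized kernel skips both the computation and the `gw` buffer traffic.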

View File

@@ -77,7 +77,7 @@ void RMSNorm::eval_gpu(
     group_dims = MTL::Size(threadgroup_size, 1, 1);
   }
 
-  uint32_t w_stride = w.strides()[0];
+  uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(
       x.data_shared_ptr() == nullptr ? out : x, 0);
@@ -101,20 +101,16 @@ void RMSNormVJP::eval_gpu(
   // Ensure row contiguity. We could relax this step by checking that the array
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
-  std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) -> const array& {
+  auto check_input = [&d, &s](const array& x) -> array {
     if (x.flags().row_contiguous) {
       return x;
     }
-
-    // Make sure we'll only ever allocate once. The point of that goes beyond
-    // the minor optimization. We need to ensure that there will be no
-    // reallocation such that the references won't change when we
-    // push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
-    copies.reserve(3);
-
-    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-    copy_gpu(x, copies.back(), CopyType::General, s);
-    return copies.back();
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+    return x_copy;
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
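The deleted comment block was defending against a classic pitfall: the old `check_input` returned a `const array&` pointing into `copies`, so any later `push_back` that reallocated the vector would have invalidated that reference, hence the up-front `reserve(3)`. The new lambda returns by value and hands lifetime management to `d.add_temporary`, so no references into a growing container exist at all. A standalone sketch of the hazard (illustrative only, not MLX code):

#include <vector>

int main() {
  std::vector<int> copies;
  copies.reserve(3);                 // the old code's fix: with capacity
  copies.push_back(1);               // pre-reserved, the push_backs below
  const int& first = copies.back();  // never reallocate, so this reference
  copies.push_back(2);               // remains valid
  copies.push_back(3);
  return first; // safe only because of the reserve() above
}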
@@ -122,6 +118,9 @@ void RMSNormVJP::eval_gpu(
   array& gx = outputs[0];
   array& gw = outputs[1];
 
+  // Check whether we had a weight
+  bool has_w = w.ndim() != 0;
+
   // Allocate space for the outputs
   bool x_in_gx = false;
   bool g_in_gx = false;
@@ -140,15 +139,18 @@ void RMSNormVJP::eval_gpu(
   // Allocate the gradient accumulator gw and a temporary to store the
   // gradients before they are accumulated.
-  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
+  array gw_temp =
+      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
   bool g_in_gw = false;
-  if (!g_in_gx && g.is_donatable()) {
-    gw_temp.move_shared_buffer(g);
-    g_in_gw = true;
-  } else {
-    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+  if (has_w) {
+    if (!g_in_gx && g.is_donatable()) {
+      gw_temp.move_shared_buffer(g);
+      g_in_gw = true;
+    } else {
+      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      d.add_temporary(gw_temp, s.index);
+    }
   }
-  copies.push_back(gw_temp);
 
   gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
 
   const int simd_size = 32;
@@ -159,9 +161,15 @@ void RMSNormVJP::eval_gpu(
     op_name += "_looped";
   }
   op_name += type_to_name(gx);
+
+  std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
+  metal::MTLFCList func_consts = {
+      {&has_w, MTL::DataType::DataTypeBool, 20},
+  };
+
   auto& compute_encoder = d.get_command_encoder(s.index);
   {
-    auto kernel = d.get_kernel(op_name);
+    auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
     MTL::Size grid_dims, group_dims;
     if (axis_size <= looped_limit) {
@@ -179,7 +187,7 @@ void RMSNormVJP::eval_gpu(
       group_dims = MTL::Size(threadgroup_size, 1, 1);
     }
 
-    uint32_t w_stride = w.strides()[0];
+    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
     compute_encoder.set_compute_pipeline_state(kernel);
     compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
     compute_encoder.set_input_array(w, 1);
@@ -192,12 +200,12 @@ void RMSNormVJP::eval_gpu(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 
-  ReductionPlan plan(
-      ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
-  strided_reduce_general_dispatch(
-      gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
-
-  d.add_temporaries(std::move(copies), s.index);
+  if (has_w) {
+    ReductionPlan plan(
+        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
+    strided_reduce_general_dispatch(
+        gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
+  }
 }
 
 void LayerNorm::eval_gpu(
@@ -295,20 +303,16 @@ void LayerNormVJP::eval_gpu(
   // Ensure row contiguity. We could relax this step by checking that the array
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
-  std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) -> const array& {
+  auto check_input = [&d, &s](const array& x) -> array {
     if (x.flags().row_contiguous) {
       return x;
     }
-
-    // Make sure we'll only ever allocate once. The point of that goes beyond
-    // the minor optimization. We need to ensure that there will be no
-    // reallocation such that the references won't change when we
-    // push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
-    copies.reserve(3);
-
-    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-    copy_gpu(x, copies.back(), CopyType::General, s);
-    return copies.back();
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+    return x_copy;
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
@@ -318,6 +322,9 @@ void LayerNormVJP::eval_gpu(
   array& gw = outputs[1];
   array& gb = outputs[2];
 
+  // Check whether we had a weight
+  bool has_w = w.ndim() != 0;
+
   // Allocate space for the outputs
   bool x_in_gx = false;
   bool g_in_gx = false;
@@ -336,15 +343,18 @@ void LayerNormVJP::eval_gpu(
   // Allocate a temporary to store the gradients for w and allocate the output
   // gradient accumulators.
-  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
+  array gw_temp =
+      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
   bool g_in_gw = false;
-  if (!g_in_gx && g.is_donatable()) {
-    gw_temp.move_shared_buffer(g);
-    g_in_gw = true;
-  } else {
-    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+  if (has_w) {
+    if (!g_in_gx && g.is_donatable()) {
+      gw_temp.move_shared_buffer(g);
+      g_in_gw = true;
+    } else {
+      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+    }
+    d.add_temporary(gw_temp, s.index);
   }
-  copies.push_back(gw_temp);
 
   gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
   gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
@@ -372,8 +382,14 @@ void LayerNormVJP::eval_gpu(
     op_name += "_looped";
   }
   op_name += type_to_name(gx);
+
+  std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
+  metal::MTLFCList func_consts = {
+      {&has_w, MTL::DataType::DataTypeBool, 20},
+  };
+
   {
-    auto kernel = d.get_kernel(op_name);
+    auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
     MTL::Size grid_dims, group_dims;
     if (axis_size <= looped_limit) {
@@ -404,14 +420,12 @@ void LayerNormVJP::eval_gpu(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 
-  if (gw.ndim() == 1 && gw.size() == axis_size) {
+  if (has_w) {
     ReductionPlan plan(
         ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
     strided_reduce_general_dispatch(
         gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
   }
-
-  d.add_temporaries(std::move(copies), s.index);
 }
 
 } // namespace mlx::core::fast
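In both VJPs the per-row weight gradients accumulate into `gw_temp` of shape (n_rows, axis_size) and are then summed over rows into `gw`. The guarded reduction computes, in effect (a plain C++ reference loop for the semantics only; the real work is the GPU dispatch via `strided_reduce_general_dispatch`):

// gw[j] = sum over rows i of gw_temp[i, j] for a row-major buffer.
void reduce_gw(const float* gw_temp, float* gw, int n_rows, int axis_size) {
  for (int j = 0; j < axis_size; j++) {
    float acc = 0.0f;
    for (int i = 0; i < n_rows; i++) {
      acc += gw_temp[i * axis_size + j]; // column j across all rows
    }
    gw[j] = acc;
  }
}

Skipping this dispatch when `has_w` is false avoids both the (n_rows, axis_size) temporary and an extra kernel launch on the weightless path.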

View File

@@ -54,30 +54,34 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
 array rms_norm(
     const array& x,
-    const array& weight,
+    const std::optional<array>& weight,
     float eps,
     StreamOrDevice s_ /* = {} */) {
+  bool has_weight = weight.has_value();
   if (x.ndim() == 0) {
     std::ostringstream msg;
     msg << "[rms_norm] Input must have at least 1 dimension but got input with "
           "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (weight.ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
-        << " dimensions.";
-    throw std::invalid_argument(msg.str());
-  }
-  if (weight.size() != x.shape(-1)) {
-    std::ostringstream msg;
-    msg << "[rms_norm] weight must have the same size as the last dimension of"
-          " x but has "
-        << weight.size() << " elements.";
-    throw std::invalid_argument(msg.str());
+  if (has_weight) {
+    if ((*weight).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[rms_norm] weight must have 1 dimension but has "
+          << (*weight).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*weight).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[rms_norm] weight must have the same size as the last dimension of"
+            " x but has "
+          << (*weight).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }
-  auto out_type = result_type(x, weight);
+  auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype();
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[rms_norm] Received unsupported type " << out_type << ".";
@@ -85,27 +89,36 @@ array rms_norm(
   }
   auto s = to_stream(s_);
-  auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
-    auto x = astype(inputs[0], float32, s);
-    x = multiply(
-        x,
-        rsqrt(
-            add(mean(square(x, s), -1, /* keepdims */ true, s),
-                array(eps, float32),
-                s),
-            s),
-        s);
-    x = astype(x, out_type, s);
-    return std::vector<array>{multiply(inputs[1], x, s)};
-  };
+  auto fallback =
+      [has_weight, eps, out_type, s](const std::vector<array>& inputs) {
+        auto x = astype(inputs[0], float32, s);
+        x = multiply(
+            x,
+            rsqrt(
+                add(mean(square(x, s), -1, /* keepdims */ true, s),
+                    array(eps, float32),
+                    s),
+                s),
+            s);
+        x = astype(x, out_type, s);
+        if (has_weight) {
+          x = multiply(x, inputs[1], s);
+        }
+        return std::vector<array>{x};
+      };
+
+  auto passed_weight =
+      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
+
   if (s.device == Device::gpu) {
     return array(
         x.shape(),
         out_type,
         std::make_shared<RMSNorm>(s, fallback, eps),
-        {astype(x, out_type, s), astype(weight, out_type, s)});
+        {astype(x, out_type, s), passed_weight});
   }
-  return fallback({x, weight})[0];
+  return fallback({x, passed_weight})[0];
 }
 
 std::vector<array> RMSNorm::vjp(
@@ -141,8 +154,12 @@ std::vector<array> RMSNorm::vjp(
     // df/dw
     std::vector<int> axes(g.ndim() - 1);
     std::iota(axes.begin(), axes.end(), 0);
-    vjps.push_back(
-        sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
+    if (w.ndim() == 0) {
+      vjps.push_back(zeros_like(w, s));
+    } else {
+      vjps.push_back(sum(
+          multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
+    }
     return vjps;
   };
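When no user weight exists, the primitive still carries the scalar stand-in `array(1, out_type)` as its second input, so the VJP must return an entry for it; the new branch returns `zeros_like(w, s)` instead of performing a meaningless reduction. A hedged end-to-end sketch of the weightless path through the public API (shapes, values, and eps are arbitrary):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = ones({4, 8});
  auto f = [](const std::vector<array>& ins) {
    // New optional-weight signature: normalize without any scaling.
    return std::vector<array>{fast::rms_norm(ins[0], std::nullopt, 1e-5f)};
  };
  // Only x is a traced input here, so only dL/dx comes back; inside the
  // primitive the scalar stand-in weight receives the zero cotangent above.
  auto [outs, grads] = vjp(f, {x}, {ones({4, 8})});
  return 0;
}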
@@ -177,28 +194,30 @@ array layer_norm(
     const std::optional<array>& bias,
     float eps,
     StreamOrDevice s_ /* = {} */) {
+  bool has_weight = weight.has_value();
+  bool has_bias = bias.has_value();
+
   if (x.ndim() == 0) {
     std::ostringstream msg;
     msg << "[layer_norm] Input must have at least 1 dimension but got input with "
           "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (weight.has_value() && (*weight).ndim() != 1) {
+  if (has_weight && (*weight).ndim() != 1) {
     std::ostringstream msg;
     msg << "[layer_norm] weight must have 1 dimension but has "
         << (*weight).ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (bias.has_value() && (*bias).ndim() != 1) {
+  if (has_bias && (*bias).ndim() != 1) {
     std::ostringstream msg;
     msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
         << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  auto out_type = (weight.has_value())
-      ? ((bias.has_value()) ? result_type(x, *weight, *bias)
-                            : result_type(x, *weight))
+  auto out_type = (has_weight)
+      ? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight))
       : x.dtype();
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
@@ -207,8 +226,6 @@ array layer_norm(
   }
   auto s = to_stream(s_);
 
-  bool has_weight = weight.has_value();
-  bool has_bias = bias.has_value();
   auto fallback = [has_weight, has_bias, eps, out_type, s](
                       const std::vector<array>& inputs) {
     auto x = astype(inputs[0], float32, s);
@@ -234,9 +251,9 @@ array layer_norm(
   };
 
   auto passed_weight =
-      astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
+      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
   auto passed_bias =
-      astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
+      (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
 
   if (s.device == Device::gpu) {
     return array(

View File

@@ -10,7 +10,7 @@ namespace mlx::core::fast {
 array rms_norm(
     const array& x,
-    const array& weight,
+    const std::optional<array>& weight,
     float eps,
     StreamOrDevice s = {});
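With the new declaration, callers opt out of scaling by passing `std::nullopt` (a minimal usage sketch; shapes and eps are arbitrary):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = ones({2, 16});
  array w = ones({16});
  array y_scaled = fast::rms_norm(x, w, 1e-5f);            // weighted RMS norm
  array y_plain = fast::rms_norm(x, std::nullopt, 1e-5f);  // normalization only
  eval(y_scaled, y_plain);
  return 0;
}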