RMS norm without scaling (#1915)

Repository: https://github.com/ml-explore/mlx
parent 5d68082881
commit 5e6c130d93, committed via GitHub
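What this change does, as read from the diff below: the fast RMS norm and layer norm primitives can now run without a weight (i.e. without scaling). The backward Metal kernels gain a function constant has_w (index 20) so the weight-gradient store compiles away when no weight is passed; on the host side, RMSNormVJP::eval_gpu and LayerNormVJP::eval_gpu allocate and reduce the weight-gradient accumulator gw_temp only when w.ndim() != 0, pass a weight stride of 0 for a 0-d placeholder weight, and request kernels specialized per has_w through a hashed name and a function-constant list. Independently, the check_input helper is simplified to return copies by value and register them as device temporaries instead of tracking them in a local copies vector.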
First changed file (the layer norm Metal kernels):

@@ -7,6 +7,8 @@
 
 using namespace metal;
 
+constant bool has_w [[function_constant(20)]];
+
 template <typename T, int N_READS = RMS_N_READS>
 [[kernel]] void layer_norm_single_row(
     const device T* x,
@@ -327,7 +329,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           normalizer * (thread_w[i] * thread_g[i] - meanwg) -
           thread_x[i] * meanwgxc * normalizer2);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      }
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
@@ -336,7 +340,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           normalizer * (thread_w[i] * thread_g[i] - meanwg) -
           thread_x[i] * meanwgxc * normalizer2);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
+      }
     }
   }
 }
@@ -465,7 +471,9 @@ template <typename T, int N_READS = RMS_N_READS>
       float gi = g[i + r];
       gx[i + r] = static_cast<T>(
           normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
-      gw[i + r] = static_cast<T>(gi * xi);
+      if (has_w) {
+        gw[i + r] = static_cast<T>(gi * xi);
+      }
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
@@ -475,7 +483,9 @@ template <typename T, int N_READS = RMS_N_READS>
       float gi = g[i + r];
       gx[i + r] = static_cast<T>(
           normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
-      gw[i + r] = static_cast<T>(gi * xi);
+      if (has_w) {
+        gw[i + r] = static_cast<T>(gi * xi);
+      }
     }
   }
 }
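For orientation, these kernels compute the standard layer norm VJP. With $\hat{x}_i = (x_i - \mu)/\sigma$, cotangent $g$, and row mean $\overline{u} = \frac{1}{D}\sum_j u_j$:

$$\frac{\partial L}{\partial x_i} = \frac{1}{\sigma}\left(w_i g_i - \overline{wg}\right) - \hat{x}_i\,\overline{wg\hat{x}}\,\frac{1}{\sigma}, \qquad \frac{\partial L}{\partial w_i} = \sum_{\text{rows}} g_i \hat{x}_i.$$

Only the weight gradient depends on a weight being present, so the hunks wrap just the gw store in if (has_w) and leave the gx expression untouched. (The symbol-to-identifier mapping, e.g. normalizer = 1/sigma, is my reading of the kernel names, not something the diff states.)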
Second changed file (the RMS norm Metal kernels):

@@ -7,6 +7,8 @@
 
 using namespace metal;
 
+constant bool has_w [[function_constant(20)]];
+
 template <typename T, int N_READS = RMS_N_READS>
 [[kernel]] void rms_single_row(
     const device T* x,
@@ -243,7 +245,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           thread_g[i] * thread_w[i] * normalizer -
           thread_x[i] * meangwx * normalizer3);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      }
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
@@ -251,7 +255,9 @@ template <typename T, int N_READS = RMS_N_READS>
       gx[i] = static_cast<T>(
           thread_g[i] * thread_w[i] * normalizer -
           thread_x[i] * meangwx * normalizer3);
-      gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      if (has_w) {
+        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
+      }
     }
   }
 }
@@ -351,7 +357,9 @@ template <typename T, int N_READS = RMS_N_READS>
 
       gx[i + r] =
           static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
-      gw[i + r] = static_cast<T>(gi * xi * normalizer);
+      if (has_w) {
+        gw[i + r] = static_cast<T>(gi * xi * normalizer);
+      }
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
@@ -362,7 +370,9 @@ template <typename T, int N_READS = RMS_N_READS>
 
      gx[i + r] =
           static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
-      gw[i + r] = static_cast<T>(gi * xi * normalizer);
+      if (has_w) {
+        gw[i + r] = static_cast<T>(gi * xi * normalizer);
+      }
     }
   }
 }
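The RMS norm hunks follow the same pattern. With $r = \sqrt{\frac{1}{D}\sum_j x_j^2 + \varepsilon}$, the kernel identifiers appear to map as normalizer $= 1/r$, normalizer3 $= 1/r^3$, and meangwx $= \frac{1}{D}\sum_j g_j w_j x_j$ (my reading), so the stores implement

$$\frac{\partial L}{\partial x_i} = \frac{g_i w_i}{r} - \frac{x_i}{r^3}\cdot\frac{1}{D}\sum_j g_j w_j x_j, \qquad \frac{\partial L}{\partial w_i} = \sum_{\text{rows}} \frac{g_i x_i}{r}.$$

The per-row values $g_i x_i / r$ are staged in gw_temp and summed over rows afterwards, which is why the host code below also puts that reduction behind has_w.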
Third changed file (host-side dispatch in mlx::core::fast):

@@ -77,7 +77,7 @@ void RMSNorm::eval_gpu(
     group_dims = MTL::Size(threadgroup_size, 1, 1);
   }
 
-  uint32_t w_stride = w.strides()[0];
+  uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(
       x.data_shared_ptr() == nullptr ? out : x, 0);
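When there is no weight, w is a 0-d placeholder, so w.strides()[0] would read a stride that does not exist; the ternary passes 0 instead. A zero stride also makes any strided weight read collapse onto a single element, the usual broadcast trick. A minimal sketch of the idea (illustrative names, not the kernels' actual indexing code):

#include <cstdint>

// With w_stride == 0 every "per-element" weight read collapses to w[0],
// so a dummy weight never walks out of bounds; with w_stride == 1 it is w[i].
inline float weight_at(const float* w, uint32_t w_stride, uint32_t i) {
  return w[i * w_stride];
}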
@@ -101,20 +101,16 @@ void RMSNormVJP::eval_gpu(
   // Ensure row contiguity. We could relax this step by checking that the array
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
-  std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) -> const array& {
+  auto check_input = [&d, &s](const array& x) -> array {
     if (x.flags().row_contiguous) {
       return x;
     }
-    // Make sure we 'll only ever allocate once. The point of that goes beyond
-    // the minor optimization. We need to ensure that there will be no
-    // reallocation such that the references won't change when we
-    // push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
-    copies.reserve(3);
-
-    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-    copy_gpu(x, copies.back(), CopyType::General, s);
-    return copies.back();
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+
+    return x_copy;
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
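The removed comment documents why the old code had to call copies.reserve(3): check_input returned a const array& into the copies vector, and any later push_back that reallocated would have invalidated it. Returning by value and handing lifetime over to d.add_temporary removes that footgun entirely. A self-contained sketch of the hazard the reserve() guarded against (generic C++, not MLX code):

#include <vector>

int main() {
  std::vector<int> v;
  v.push_back(1);
  const int& first = v.back();  // reference into v's buffer
  for (int i = 0; i < 1000; ++i) {
    v.push_back(i);             // may reallocate the buffer
  }
  return first;                 // undefined behavior if it did
}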
@@ -122,6 +118,9 @@
   array& gx = outputs[0];
   array& gw = outputs[1];
 
+  // Check whether we had a weight
+  bool has_w = w.ndim() != 0;
+
   // Allocate space for the outputs
   bool x_in_gx = false;
   bool g_in_gx = false;
@@ -140,15 +139,18 @@
 
   // Allocate the gradient accumulator gw and a temporary to store the
   // gradients before they are accumulated.
-  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
+  array gw_temp =
+      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
   bool g_in_gw = false;
-  if (!g_in_gx && g.is_donatable()) {
-    gw_temp.move_shared_buffer(g);
-    g_in_gw = true;
-  } else {
-    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+  if (has_w) {
+    if (!g_in_gx && g.is_donatable()) {
+      gw_temp.move_shared_buffer(g);
+      g_in_gw = true;
+    } else {
+      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      d.add_temporary(gw_temp, s.index);
+    }
   }
-  copies.push_back(gw_temp);
   gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
 
   const int simd_size = 32;
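The donation logic itself is unchanged: when the cotangent g was not already donated into gx and its buffer is donatable, gw_temp takes over g's buffer instead of allocating fresh memory. What is new is that the whole block is skipped without a weight; gw_temp is then just the 0-d weight placeholder and no accumulator is allocated or registered as a temporary.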
@@ -159,9 +161,15 @@
     op_name += "_looped";
   }
   op_name += type_to_name(gx);
 
+  std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
+  metal::MTLFCList func_consts = {
+      {&has_w, MTL::DataType::DataTypeBool, 20},
+  };
+
   auto& compute_encoder = d.get_command_encoder(s.index);
   {
-    auto kernel = d.get_kernel(op_name);
+    auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
+
     MTL::Size grid_dims, group_dims;
     if (axis_size <= looped_limit) {
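The index 20 matches the [[function_constant(20)]] declarations added to both kernels. Metal folds function constants at pipeline-build time, so the if (has_w) branches compile away, and the _w/_now suffix in hash_name keeps the two specializations distinct in the kernel cache. Underneath MLX's d.get_kernel(op_name, "mlx", hash_name, func_consts) wrapper, specializing a function with raw metal-cpp looks roughly like this (a sketch assuming device and library handles and an illustrative kernel name; not MLX's implementation):

#include <Metal/Metal.hpp>

// Bind function constant 20, then build the specialized pipeline.
// Error handling is elided; "rms_vjpfloat32" is an illustrative name.
MTL::ComputePipelineState* specialized_pipeline(
    MTL::Device* device, MTL::Library* library, bool has_w) {
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  consts->setConstantValue(&has_w, MTL::DataType::DataTypeBool, 20);
  NS::Error* error = nullptr;
  auto* name = NS::String::string("rms_vjpfloat32", NS::UTF8StringEncoding);
  MTL::Function* fn = library->newFunction(name, consts, &error);
  return fn ? device->newComputePipelineState(fn, &error) : nullptr;
}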
@@ -179,7 +187,7 @@ void RMSNormVJP::eval_gpu(
     group_dims = MTL::Size(threadgroup_size, 1, 1);
   }
 
-  uint32_t w_stride = w.strides()[0];
+  uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
   compute_encoder.set_input_array(w, 1);
@@ -192,12 +200,12 @@
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 
-  ReductionPlan plan(
-      ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
-  strided_reduce_general_dispatch(
-      gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
-
-  d.add_temporaries(std::move(copies), s.index);
+  if (has_w) {
+    ReductionPlan plan(
+        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
+    strided_reduce_general_dispatch(
+        gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
+  }
 }
 
 void LayerNorm::eval_gpu(
@@ -295,20 +303,16 @@ void LayerNormVJP::eval_gpu(
   // Ensure row contiguity. We could relax this step by checking that the array
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
-  std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) -> const array& {
+  auto check_input = [&d, &s](const array& x) -> array {
     if (x.flags().row_contiguous) {
       return x;
     }
-    // Make sure we 'll only ever allocate once. The point of that goes beyond
-    // the minor optimization. We need to ensure that there will be no
-    // reallocation such that the references won't change when we
-    // push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
-    copies.reserve(3);
-
-    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-    copy_gpu(x, copies.back(), CopyType::General, s);
-    return copies.back();
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+
+    return x_copy;
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
@@ -318,6 +322,9 @@ void LayerNormVJP::eval_gpu(
   array& gw = outputs[1];
   array& gb = outputs[2];
 
+  // Check whether we had a weight
+  bool has_w = w.ndim() != 0;
+
   // Allocate space for the outputs
   bool x_in_gx = false;
   bool g_in_gx = false;
@@ -336,15 +343,18 @@
 
   // Allocate a temporary to store the gradients for w and allocate the output
   // gradient accumulators.
-  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
+  array gw_temp =
+      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
   bool g_in_gw = false;
-  if (!g_in_gx && g.is_donatable()) {
-    gw_temp.move_shared_buffer(g);
-    g_in_gw = true;
-  } else {
-    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+  if (has_w) {
+    if (!g_in_gx && g.is_donatable()) {
+      gw_temp.move_shared_buffer(g);
+      g_in_gw = true;
+    } else {
+      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+    }
+    d.add_temporary(gw_temp, s.index);
   }
-  copies.push_back(gw_temp);
   gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
   gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
 
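Note that only the weight path is gated here: the bias gradient gb is still always allocated, since, judging from these hunks, the bias VJP does not depend on whether a scaling weight was supplied.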
@@ -372,8 +382,14 @@ void LayerNormVJP::eval_gpu(
     op_name += "_looped";
   }
   op_name += type_to_name(gx);
 
+  std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
+  metal::MTLFCList func_consts = {
+      {&has_w, MTL::DataType::DataTypeBool, 20},
+  };
+
   {
-    auto kernel = d.get_kernel(op_name);
+    auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
+
     MTL::Size grid_dims, group_dims;
     if (axis_size <= looped_limit) {
@@ -404,14 +420,12 @@
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 
-  if (gw.ndim() == 1 && gw.size() == axis_size) {
+  if (has_w) {
     ReductionPlan plan(
         ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
     strided_reduce_general_dispatch(
         gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
   }
-
-  d.add_temporaries(std::move(copies), s.index);
 }
 
 } // namespace mlx::core::fast