mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
RMS norm without scaling (#1915)
This commit is contained in:
parent
5d68082881
commit
5e6c130d93
@ -10,7 +10,12 @@ def layer_norm(x, w, b, eps):
|
|||||||
x = x.astype(mx.float32)
|
x = x.astype(mx.float32)
|
||||||
mu = mx.mean(x, -1, keepdims=True)
|
mu = mx.mean(x, -1, keepdims=True)
|
||||||
v = mx.var(x, -1, keepdims=True)
|
v = mx.var(x, -1, keepdims=True)
|
||||||
return (x - mu) * mx.rsqrt(v + eps) * w + b
|
y = (x - mu) * mx.rsqrt(v + eps)
|
||||||
|
if w is not None:
|
||||||
|
y = y * w
|
||||||
|
if b is not None:
|
||||||
|
y = y + b
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
def time_layer_norm():
|
def time_layer_norm():
|
||||||
@ -36,6 +41,28 @@ def time_layer_norm():
|
|||||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||||
|
|
||||||
|
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
|
def layer_norm_loop(g, x):
|
||||||
|
gx = x
|
||||||
|
for _ in range(32):
|
||||||
|
gx = g(gx, y)
|
||||||
|
return gx
|
||||||
|
|
||||||
|
time_fn(layer_norm_loop, g1, x)
|
||||||
|
time_fn(layer_norm_loop, g2, x)
|
||||||
|
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||||
|
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_layer_norm()
|
time_layer_norm()
|
||||||
|
@ -9,7 +9,10 @@ def rms_norm(x, w, eps):
|
|||||||
ot = x.dtype
|
ot = x.dtype
|
||||||
x = x.astype(mx.float32)
|
x = x.astype(mx.float32)
|
||||||
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||||
return (x * n).astype(ot) * w
|
y = (x * n).astype(ot)
|
||||||
|
if w is not None:
|
||||||
|
y = y * w
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
def time_rms_norm():
|
def time_rms_norm():
|
||||||
@ -34,6 +37,27 @@ def time_rms_norm():
|
|||||||
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
||||||
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
||||||
|
|
||||||
|
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
mx.eval(x, w, y)
|
||||||
|
|
||||||
|
def rms_norm_loop(g, x):
|
||||||
|
gx = x
|
||||||
|
for _ in range(32):
|
||||||
|
gx = g(gx, y)
|
||||||
|
return gx
|
||||||
|
|
||||||
|
time_fn(rms_norm_loop, g1, x)
|
||||||
|
time_fn(rms_norm_loop, g2, x)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g1), x)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_rms_norm()
|
time_rms_norm()
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
constant bool has_w [[function_constant(20)]];
|
||||||
|
|
||||||
template <typename T, int N_READS = RMS_N_READS>
|
template <typename T, int N_READS = RMS_N_READS>
|
||||||
[[kernel]] void layer_norm_single_row(
|
[[kernel]] void layer_norm_single_row(
|
||||||
const device T* x,
|
const device T* x,
|
||||||
@ -327,7 +329,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
gx[i] = static_cast<T>(
|
gx[i] = static_cast<T>(
|
||||||
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||||
thread_x[i] * meanwgxc * normalizer2);
|
thread_x[i] * meanwgxc * normalizer2);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
if (has_w) {
|
||||||
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -336,7 +340,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
gx[i] = static_cast<T>(
|
gx[i] = static_cast<T>(
|
||||||
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
|
||||||
thread_x[i] * meanwgxc * normalizer2);
|
thread_x[i] * meanwgxc * normalizer2);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
if (has_w) {
|
||||||
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -465,7 +471,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float gi = g[i + r];
|
float gi = g[i + r];
|
||||||
gx[i + r] = static_cast<T>(
|
gx[i + r] = static_cast<T>(
|
||||||
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||||
gw[i + r] = static_cast<T>(gi * xi);
|
if (has_w) {
|
||||||
|
gw[i + r] = static_cast<T>(gi * xi);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -475,7 +483,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float gi = g[i + r];
|
float gi = g[i + r];
|
||||||
gx[i + r] = static_cast<T>(
|
gx[i + r] = static_cast<T>(
|
||||||
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
|
||||||
gw[i + r] = static_cast<T>(gi * xi);
|
if (has_w) {
|
||||||
|
gw[i + r] = static_cast<T>(gi * xi);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
constant bool has_w [[function_constant(20)]];
|
||||||
|
|
||||||
template <typename T, int N_READS = RMS_N_READS>
|
template <typename T, int N_READS = RMS_N_READS>
|
||||||
[[kernel]] void rms_single_row(
|
[[kernel]] void rms_single_row(
|
||||||
const device T* x,
|
const device T* x,
|
||||||
@ -243,7 +245,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
gx[i] = static_cast<T>(
|
gx[i] = static_cast<T>(
|
||||||
thread_g[i] * thread_w[i] * normalizer -
|
thread_g[i] * thread_w[i] * normalizer -
|
||||||
thread_x[i] * meangwx * normalizer3);
|
thread_x[i] * meangwx * normalizer3);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
if (has_w) {
|
||||||
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -251,7 +255,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
gx[i] = static_cast<T>(
|
gx[i] = static_cast<T>(
|
||||||
thread_g[i] * thread_w[i] * normalizer -
|
thread_g[i] * thread_w[i] * normalizer -
|
||||||
thread_x[i] * meangwx * normalizer3);
|
thread_x[i] * meangwx * normalizer3);
|
||||||
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
if (has_w) {
|
||||||
|
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -351,7 +357,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
|
|
||||||
gx[i + r] =
|
gx[i + r] =
|
||||||
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
if (has_w) {
|
||||||
|
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -362,7 +370,9 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
|
|
||||||
gx[i + r] =
|
gx[i + r] =
|
||||||
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
|
||||||
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
if (has_w) {
|
||||||
|
gw[i + r] = static_cast<T>(gi * xi * normalizer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ void RMSNorm::eval_gpu(
|
|||||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t w_stride = w.strides()[0];
|
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
compute_encoder.set_input_array(
|
compute_encoder.set_input_array(
|
||||||
x.data_shared_ptr() == nullptr ? out : x, 0);
|
x.data_shared_ptr() == nullptr ? out : x, 0);
|
||||||
@ -101,20 +101,16 @@ void RMSNormVJP::eval_gpu(
|
|||||||
// Ensure row contiguity. We could relax this step by checking that the array
|
// Ensure row contiguity. We could relax this step by checking that the array
|
||||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||||
// same as the cotangent strides but for now this is simpler.
|
// same as the cotangent strides but for now this is simpler.
|
||||||
std::vector<array> copies;
|
auto check_input = [&d, &s](const array& x) -> array {
|
||||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
|
||||||
if (x.flags().row_contiguous) {
|
if (x.flags().row_contiguous) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
// Make sure we 'll only ever allocate once. The point of that goes beyond
|
|
||||||
// the minor optimization. We need to ensure that there will be no
|
|
||||||
// reallocation such that the references won't change when we
|
|
||||||
// push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
|
|
||||||
copies.reserve(3);
|
|
||||||
|
|
||||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
return copies.back();
|
d.add_temporary(x_copy, s.index);
|
||||||
|
|
||||||
|
return x_copy;
|
||||||
};
|
};
|
||||||
const array& x = check_input(inputs[0]);
|
const array& x = check_input(inputs[0]);
|
||||||
const array& w = inputs[1];
|
const array& w = inputs[1];
|
||||||
@ -122,6 +118,9 @@ void RMSNormVJP::eval_gpu(
|
|||||||
array& gx = outputs[0];
|
array& gx = outputs[0];
|
||||||
array& gw = outputs[1];
|
array& gw = outputs[1];
|
||||||
|
|
||||||
|
// Check whether we had a weight
|
||||||
|
bool has_w = w.ndim() != 0;
|
||||||
|
|
||||||
// Allocate space for the outputs
|
// Allocate space for the outputs
|
||||||
bool x_in_gx = false;
|
bool x_in_gx = false;
|
||||||
bool g_in_gx = false;
|
bool g_in_gx = false;
|
||||||
@ -140,15 +139,18 @@ void RMSNormVJP::eval_gpu(
|
|||||||
|
|
||||||
// Allocate the gradient accumulator gw and a temporary to store the
|
// Allocate the gradient accumulator gw and a temporary to store the
|
||||||
// gradients before they are accumulated.
|
// gradients before they are accumulated.
|
||||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
array gw_temp =
|
||||||
|
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||||
bool g_in_gw = false;
|
bool g_in_gw = false;
|
||||||
if (!g_in_gx && g.is_donatable()) {
|
if (has_w) {
|
||||||
gw_temp.move_shared_buffer(g);
|
if (!g_in_gx && g.is_donatable()) {
|
||||||
g_in_gw = true;
|
gw_temp.move_shared_buffer(g);
|
||||||
} else {
|
g_in_gw = true;
|
||||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
} else {
|
||||||
|
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||||
|
d.add_temporary(gw_temp, s.index);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
copies.push_back(gw_temp);
|
|
||||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||||
|
|
||||||
const int simd_size = 32;
|
const int simd_size = 32;
|
||||||
@ -159,9 +161,15 @@ void RMSNormVJP::eval_gpu(
|
|||||||
op_name += "_looped";
|
op_name += "_looped";
|
||||||
}
|
}
|
||||||
op_name += type_to_name(gx);
|
op_name += type_to_name(gx);
|
||||||
|
|
||||||
|
std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
|
||||||
|
metal::MTLFCList func_consts = {
|
||||||
|
{&has_w, MTL::DataType::DataTypeBool, 20},
|
||||||
|
};
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
{
|
{
|
||||||
auto kernel = d.get_kernel(op_name);
|
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
|
||||||
|
|
||||||
MTL::Size grid_dims, group_dims;
|
MTL::Size grid_dims, group_dims;
|
||||||
if (axis_size <= looped_limit) {
|
if (axis_size <= looped_limit) {
|
||||||
@ -179,7 +187,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t w_stride = w.strides()[0];
|
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
|
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
|
||||||
compute_encoder.set_input_array(w, 1);
|
compute_encoder.set_input_array(w, 1);
|
||||||
@ -192,12 +200,12 @@ void RMSNormVJP::eval_gpu(
|
|||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
ReductionPlan plan(
|
if (has_w) {
|
||||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
ReductionPlan plan(
|
||||||
strided_reduce_general_dispatch(
|
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
strided_reduce_general_dispatch(
|
||||||
|
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void LayerNorm::eval_gpu(
|
void LayerNorm::eval_gpu(
|
||||||
@ -295,20 +303,16 @@ void LayerNormVJP::eval_gpu(
|
|||||||
// Ensure row contiguity. We could relax this step by checking that the array
|
// Ensure row contiguity. We could relax this step by checking that the array
|
||||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||||
// same as the cotangent strides but for now this is simpler.
|
// same as the cotangent strides but for now this is simpler.
|
||||||
std::vector<array> copies;
|
auto check_input = [&d, &s](const array& x) -> array {
|
||||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
|
||||||
if (x.flags().row_contiguous) {
|
if (x.flags().row_contiguous) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
// Make sure we 'll only ever allocate once. The point of that goes beyond
|
|
||||||
// the minor optimization. We need to ensure that there will be no
|
|
||||||
// reallocation such that the references won't change when we
|
|
||||||
// push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
|
|
||||||
copies.reserve(3);
|
|
||||||
|
|
||||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
copy_gpu(x, x_copy, CopyType::General, s);
|
||||||
return copies.back();
|
d.add_temporary(x_copy, s.index);
|
||||||
|
|
||||||
|
return x_copy;
|
||||||
};
|
};
|
||||||
const array& x = check_input(inputs[0]);
|
const array& x = check_input(inputs[0]);
|
||||||
const array& w = inputs[1];
|
const array& w = inputs[1];
|
||||||
@ -318,6 +322,9 @@ void LayerNormVJP::eval_gpu(
|
|||||||
array& gw = outputs[1];
|
array& gw = outputs[1];
|
||||||
array& gb = outputs[2];
|
array& gb = outputs[2];
|
||||||
|
|
||||||
|
// Check whether we had a weight
|
||||||
|
bool has_w = w.ndim() != 0;
|
||||||
|
|
||||||
// Allocate space for the outputs
|
// Allocate space for the outputs
|
||||||
bool x_in_gx = false;
|
bool x_in_gx = false;
|
||||||
bool g_in_gx = false;
|
bool g_in_gx = false;
|
||||||
@ -336,15 +343,18 @@ void LayerNormVJP::eval_gpu(
|
|||||||
|
|
||||||
// Allocate a temporary to store the gradients for w and allocate the output
|
// Allocate a temporary to store the gradients for w and allocate the output
|
||||||
// gradient accumulators.
|
// gradient accumulators.
|
||||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
array gw_temp =
|
||||||
|
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||||
bool g_in_gw = false;
|
bool g_in_gw = false;
|
||||||
if (!g_in_gx && g.is_donatable()) {
|
if (has_w) {
|
||||||
gw_temp.move_shared_buffer(g);
|
if (!g_in_gx && g.is_donatable()) {
|
||||||
g_in_gw = true;
|
gw_temp.move_shared_buffer(g);
|
||||||
} else {
|
g_in_gw = true;
|
||||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
} else {
|
||||||
|
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||||
|
}
|
||||||
|
d.add_temporary(gw_temp, s.index);
|
||||||
}
|
}
|
||||||
copies.push_back(gw_temp);
|
|
||||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||||
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
||||||
|
|
||||||
@ -372,8 +382,14 @@ void LayerNormVJP::eval_gpu(
|
|||||||
op_name += "_looped";
|
op_name += "_looped";
|
||||||
}
|
}
|
||||||
op_name += type_to_name(gx);
|
op_name += type_to_name(gx);
|
||||||
|
|
||||||
|
std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
|
||||||
|
metal::MTLFCList func_consts = {
|
||||||
|
{&has_w, MTL::DataType::DataTypeBool, 20},
|
||||||
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
auto kernel = d.get_kernel(op_name);
|
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
|
||||||
|
|
||||||
MTL::Size grid_dims, group_dims;
|
MTL::Size grid_dims, group_dims;
|
||||||
if (axis_size <= looped_limit) {
|
if (axis_size <= looped_limit) {
|
||||||
@ -404,14 +420,12 @@ void LayerNormVJP::eval_gpu(
|
|||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gw.ndim() == 1 && gw.size() == axis_size) {
|
if (has_w) {
|
||||||
ReductionPlan plan(
|
ReductionPlan plan(
|
||||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||||
strided_reduce_general_dispatch(
|
strided_reduce_general_dispatch(
|
||||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
95
mlx/fast.cpp
95
mlx/fast.cpp
@ -54,30 +54,34 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
|
|||||||
|
|
||||||
array rms_norm(
|
array rms_norm(
|
||||||
const array& x,
|
const array& x,
|
||||||
const array& weight,
|
const std::optional<array>& weight,
|
||||||
float eps,
|
float eps,
|
||||||
StreamOrDevice s_ /* = {} */) {
|
StreamOrDevice s_ /* = {} */) {
|
||||||
|
bool has_weight = weight.has_value();
|
||||||
|
|
||||||
if (x.ndim() == 0) {
|
if (x.ndim() == 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rms_norm] Input must have at least 1 dimension but got input with "
|
msg << "[rms_norm] Input must have at least 1 dimension but got input with "
|
||||||
"0 dimensions.";
|
"0 dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (weight.ndim() != 1) {
|
if (has_weight) {
|
||||||
std::ostringstream msg;
|
if ((*weight).ndim() != 1) {
|
||||||
msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
|
std::ostringstream msg;
|
||||||
<< " dimensions.";
|
msg << "[rms_norm] (*weight) must have 1 dimension but has "
|
||||||
throw std::invalid_argument(msg.str());
|
<< (*weight).ndim() << " dimensions.";
|
||||||
}
|
throw std::invalid_argument(msg.str());
|
||||||
if (weight.size() != x.shape(-1)) {
|
}
|
||||||
std::ostringstream msg;
|
if ((*weight).size() != x.shape(-1)) {
|
||||||
msg << "[rms_norm] weight must have the same size as the last dimension of"
|
std::ostringstream msg;
|
||||||
" x but has "
|
msg << "[rms_norm] (*weight) must have the same size as the last dimension of"
|
||||||
<< weight.size() << " elements.";
|
" x but has "
|
||||||
throw std::invalid_argument(msg.str());
|
<< (*weight).size() << " elements.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out_type = result_type(x, weight);
|
auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype();
|
||||||
if (!issubdtype(out_type, floating)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
||||||
@ -85,27 +89,36 @@ array rms_norm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto s = to_stream(s_);
|
auto s = to_stream(s_);
|
||||||
auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
|
auto fallback =
|
||||||
auto x = astype(inputs[0], float32, s);
|
[has_weight, eps, out_type, s](const std::vector<array>& inputs) {
|
||||||
x = multiply(
|
auto x = astype(inputs[0], float32, s);
|
||||||
x,
|
x = multiply(
|
||||||
rsqrt(
|
x,
|
||||||
add(mean(square(x, s), -1, /* keepdims */ true, s),
|
rsqrt(
|
||||||
array(eps, float32),
|
add(mean(square(x, s), -1, /* keepdims */ true, s),
|
||||||
|
array(eps, float32),
|
||||||
|
s),
|
||||||
s),
|
s),
|
||||||
s),
|
s);
|
||||||
s);
|
x = astype(x, out_type, s);
|
||||||
x = astype(x, out_type, s);
|
|
||||||
return std::vector<array>{multiply(inputs[1], x, s)};
|
if (has_weight) {
|
||||||
};
|
x = multiply(x, inputs[1], s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::vector<array>{x};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto passed_weight =
|
||||||
|
(has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
|
||||||
if (s.device == Device::gpu) {
|
if (s.device == Device::gpu) {
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<RMSNorm>(s, fallback, eps),
|
std::make_shared<RMSNorm>(s, fallback, eps),
|
||||||
{astype(x, out_type, s), astype(weight, out_type, s)});
|
{astype(x, out_type, s), passed_weight});
|
||||||
}
|
}
|
||||||
return fallback({x, weight})[0];
|
return fallback({x, passed_weight})[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> RMSNorm::vjp(
|
std::vector<array> RMSNorm::vjp(
|
||||||
@ -141,8 +154,12 @@ std::vector<array> RMSNorm::vjp(
|
|||||||
// df/dw
|
// df/dw
|
||||||
std::vector<int> axes(g.ndim() - 1);
|
std::vector<int> axes(g.ndim() - 1);
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
vjps.push_back(
|
if (w.ndim() == 0) {
|
||||||
sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
|
vjps.push_back(zeros_like(w, s));
|
||||||
|
} else {
|
||||||
|
vjps.push_back(sum(
|
||||||
|
multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
|
||||||
|
}
|
||||||
|
|
||||||
return vjps;
|
return vjps;
|
||||||
};
|
};
|
||||||
@ -177,28 +194,30 @@ array layer_norm(
|
|||||||
const std::optional<array>& bias,
|
const std::optional<array>& bias,
|
||||||
float eps,
|
float eps,
|
||||||
StreamOrDevice s_ /* = {} */) {
|
StreamOrDevice s_ /* = {} */) {
|
||||||
|
bool has_weight = weight.has_value();
|
||||||
|
bool has_bias = bias.has_value();
|
||||||
|
|
||||||
if (x.ndim() == 0) {
|
if (x.ndim() == 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[layer_norm] Input must have at least 1 dimension but got input with "
|
msg << "[layer_norm] Input must have at least 1 dimension but got input with "
|
||||||
"0 dimensions.";
|
"0 dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (weight.has_value() && (*weight).ndim() != 1) {
|
if (has_weight && (*weight).ndim() != 1) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[layer_norm] weight must have 1 dimension but has "
|
msg << "[layer_norm] weight must have 1 dimension but has "
|
||||||
<< (*weight).ndim() << " dimensions.";
|
<< (*weight).ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (bias.has_value() && (*bias).ndim() != 1) {
|
if (has_bias && (*bias).ndim() != 1) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
|
msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
|
||||||
<< " dimensions.";
|
<< " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto out_type = (weight.has_value())
|
auto out_type = (has_weight)
|
||||||
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight))
|
||||||
: result_type(x, *weight))
|
|
||||||
: x.dtype();
|
: x.dtype();
|
||||||
if (!issubdtype(out_type, floating)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -207,8 +226,6 @@ array layer_norm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto s = to_stream(s_);
|
auto s = to_stream(s_);
|
||||||
bool has_weight = weight.has_value();
|
|
||||||
bool has_bias = bias.has_value();
|
|
||||||
auto fallback = [has_weight, has_bias, eps, out_type, s](
|
auto fallback = [has_weight, has_bias, eps, out_type, s](
|
||||||
const std::vector<array>& inputs) {
|
const std::vector<array>& inputs) {
|
||||||
auto x = astype(inputs[0], float32, s);
|
auto x = astype(inputs[0], float32, s);
|
||||||
@ -234,9 +251,9 @@ array layer_norm(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto passed_weight =
|
auto passed_weight =
|
||||||
astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
|
(has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
|
||||||
auto passed_bias =
|
auto passed_bias =
|
||||||
astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
|
(has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
|
||||||
|
|
||||||
if (s.device == Device::gpu) {
|
if (s.device == Device::gpu) {
|
||||||
return array(
|
return array(
|
||||||
|
@ -10,7 +10,7 @@ namespace mlx::core::fast {
|
|||||||
|
|
||||||
array rms_norm(
|
array rms_norm(
|
||||||
const array& x,
|
const array& x,
|
||||||
const array& weight,
|
const std::optional<array>& weight,
|
||||||
float eps,
|
float eps,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -25,12 +25,12 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
"rms_norm",
|
"rms_norm",
|
||||||
&mx::fast::rms_norm,
|
&mx::fast::rms_norm,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"weight"_a,
|
"weight"_a.none(),
|
||||||
"eps"_a,
|
"eps"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Root Mean Square normalization (RMS norm).
|
Root Mean Square normalization (RMS norm).
|
||||||
|
|
||||||
@ -38,9 +38,9 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (array): Input array.
|
x (array): Input array.
|
||||||
weight (array): A multiplicative weight to scale the result by.
|
weight (array, optional): A multiplicative weight to scale the result by.
|
||||||
The ``weight`` should be one-dimensional with the same size
|
The ``weight`` should be one-dimensional with the same size
|
||||||
as the last axis of ``x``.
|
as the last axis of ``x``. If set to ``None`` then no scaling happens.
|
||||||
eps (float): A small additive constant for numerical stability.
|
eps (float): A small additive constant for numerical stability.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
rx = rms_norm(x, weight, eps)
|
rx = rms_norm(x, weight, eps)
|
||||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
rx = rms_norm(x, mx.ones_like(weight), eps)
|
||||||
|
rx_fast = mx.fast.rms_norm(x, None, eps)
|
||||||
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
|
||||||
for eps in epss:
|
for eps in epss:
|
||||||
dtype, _, dims = defaults
|
dtype, _, dims = defaults
|
||||||
@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
rx = rms_norm(x, weight, eps)
|
rx = rms_norm(x, weight, eps)
|
||||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
rx = rms_norm(x, mx.ones_like(weight), eps)
|
||||||
|
rx_fast = mx.fast.rms_norm(x, None, eps)
|
||||||
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
|
||||||
for dims in dimss:
|
for dims in dimss:
|
||||||
dtype, eps, _ = defaults
|
dtype, eps, _ = defaults
|
||||||
@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
rx = rms_norm(x, weight, eps)
|
rx = rms_norm(x, weight, eps)
|
||||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
rx = rms_norm(x, mx.ones_like(weight), eps)
|
||||||
|
rx_fast = mx.fast.rms_norm(x, None, eps)
|
||||||
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
|
||||||
# Test > 4096
|
# Test > 4096
|
||||||
dims, dtype, eps = 4099, mx.float32, 1e-5
|
dims, dtype, eps = 4099, mx.float32, 1e-5
|
||||||
@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
eps = 1e-5
|
eps = 1e-5
|
||||||
f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
|
f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
|
||||||
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
|
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
|
||||||
|
f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum()
|
||||||
|
f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(8, 100, D))
|
x = mx.random.uniform(shape=(8, 100, D))
|
||||||
w = mx.random.uniform(shape=(D,))
|
w = mx.random.uniform(shape=(D,))
|
||||||
@ -341,6 +352,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
|
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
|
||||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||||
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
||||||
|
gx1 = mx.grad(f3, argnums=(0,))(x, y)
|
||||||
|
gx2 = mx.grad(f4, argnums=(0,))(x, y)
|
||||||
|
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||||
|
|
||||||
D = 8192
|
D = 8192
|
||||||
x = mx.random.uniform(shape=(2, 2, D))
|
x = mx.random.uniform(shape=(2, 2, D))
|
||||||
@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
|
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
|
||||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||||
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
||||||
|
gx1 = mx.grad(f3, argnums=(0,))(x, y)
|
||||||
|
gx2 = mx.grad(f4, argnums=(0,))(x, y)
|
||||||
|
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||||
|
|
||||||
def gf(f):
|
def gf(f):
|
||||||
def inner(x, w, y):
|
def inner(x, w, y):
|
||||||
|
Loading…
Reference in New Issue
Block a user