From 5e6c130d9321b7f47a823723caa4b0e4b0404fdc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 28 Feb 2025 20:26:57 -0800 Subject: [PATCH] RMS norm without scaling (#1915) --- benchmarks/python/layer_norm_bench.py | 29 +++++- benchmarks/python/rms_norm_bench.py | 26 ++++- mlx/backend/metal/kernels/layer_norm.metal | 18 +++- mlx/backend/metal/kernels/rms_norm.metal | 18 +++- mlx/backend/metal/normalization.cpp | 108 ++++++++++++--------- mlx/fast.cpp | 95 ++++++++++-------- mlx/fast.h | 2 +- python/src/fast.cpp | 8 +- python/tests/test_fast.py | 17 ++++ 9 files changed, 220 insertions(+), 101 deletions(-) diff --git a/benchmarks/python/layer_norm_bench.py b/benchmarks/python/layer_norm_bench.py index 8b9635315..69263835a 100644 --- a/benchmarks/python/layer_norm_bench.py +++ b/benchmarks/python/layer_norm_bench.py @@ -10,7 +10,12 @@ def layer_norm(x, w, b, eps): x = x.astype(mx.float32) mu = mx.mean(x, -1, keepdims=True) v = mx.var(x, -1, keepdims=True) - return (x - mu) * mx.rsqrt(v + eps) * w + b + y = (x - mu) * mx.rsqrt(v + eps) + if w is not None: + y = y * w + if b is not None: + y = y + b + return y def time_layer_norm(): @@ -36,6 +41,28 @@ def time_layer_norm(): time_fn(layer_norm_loop, mx.compile(g1), x, w, b) time_fn(layer_norm_loop, mx.compile(g2), x, w, b) + f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() + f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() + g1 = mx.grad(f1, argnums=(0,)) + g2 = mx.grad(f2, argnums=(0,)) + + x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + w = mx.random.uniform(shape=(4096,)).astype(mx.float16) + b = mx.random.uniform(shape=(4096,)).astype(mx.float16) + y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + mx.eval(x, w, b, y) + + def layer_norm_loop(g, x): + gx = x + for _ in range(32): + gx = g(gx, y) + return gx + + time_fn(layer_norm_loop, g1, x) + time_fn(layer_norm_loop, g2, x) + time_fn(layer_norm_loop, mx.compile(g1), x) + time_fn(layer_norm_loop, mx.compile(g2), x) + if __name__ == "__main__": time_layer_norm() diff --git a/benchmarks/python/rms_norm_bench.py b/benchmarks/python/rms_norm_bench.py index a54dfe697..50f3a40b0 100644 --- a/benchmarks/python/rms_norm_bench.py +++ b/benchmarks/python/rms_norm_bench.py @@ -9,7 +9,10 @@ def rms_norm(x, w, eps): ot = x.dtype x = x.astype(mx.float32) n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) - return (x * n).astype(ot) * w + y = (x * n).astype(ot) + if w is not None: + y = y * w + return y def time_rms_norm(): @@ -34,6 +37,27 @@ def time_rms_norm(): time_fn(rms_norm_loop, mx.compile(g1), x, w) time_fn(rms_norm_loop, mx.compile(g2), x, w) + f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum() + f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum() + g1 = mx.grad(f1, argnums=(0,)) + g2 = mx.grad(f2, argnums=(0,)) + + x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + w = mx.random.uniform(shape=(4096,)).astype(mx.float16) + y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + mx.eval(x, w, y) + + def rms_norm_loop(g, x): + gx = x + for _ in range(32): + gx = g(gx, y) + return gx + + time_fn(rms_norm_loop, g1, x) + time_fn(rms_norm_loop, g2, x) + time_fn(rms_norm_loop, mx.compile(g1), x) + time_fn(rms_norm_loop, mx.compile(g2), x) + if __name__ == "__main__": time_rms_norm() diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 462eb3b94..4674a4228 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ 
b/mlx/backend/metal/kernels/layer_norm.metal @@ -7,6 +7,8 @@ using namespace metal; +constant bool has_w [[function_constant(20)]]; + template [[kernel]] void layer_norm_single_row( const device T* x, @@ -327,7 +329,9 @@ template gx[i] = static_cast( normalizer * (thread_w[i] * thread_g[i] - meanwg) - thread_x[i] * meanwgxc * normalizer2); - gw[i] = static_cast(thread_g[i] * thread_x[i]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -336,7 +340,9 @@ template gx[i] = static_cast( normalizer * (thread_w[i] * thread_g[i] - meanwg) - thread_x[i] * meanwgxc * normalizer2); - gw[i] = static_cast(thread_g[i] * thread_x[i]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); + } } } } @@ -465,7 +471,9 @@ template float gi = g[i + r]; gx[i + r] = static_cast( normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); - gw[i + r] = static_cast(gi * xi); + if (has_w) { + gw[i + r] = static_cast(gi * xi); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -475,7 +483,9 @@ template float gi = g[i + r]; gx[i + r] = static_cast( normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); - gw[i + r] = static_cast(gi * xi); + if (has_w) { + gw[i + r] = static_cast(gi * xi); + } } } } diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index f8fb53dd5..f4c1536de 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -7,6 +7,8 @@ using namespace metal; +constant bool has_w [[function_constant(20)]]; + template [[kernel]] void rms_single_row( const device T* x, @@ -243,7 +245,9 @@ template gx[i] = static_cast( thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3); - gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -251,7 +255,9 @@ template gx[i] = static_cast( thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3); - gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); + } } } } @@ -351,7 +357,9 @@ template gx[i + r] = static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); - gw[i + r] = static_cast(gi * xi * normalizer); + if (has_w) { + gw[i + r] = static_cast(gi * xi * normalizer); + } } } else { for (int i = 0; i < N_READS; i++) { @@ -362,7 +370,9 @@ template gx[i + r] = static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); - gw[i + r] = static_cast(gi * xi * normalizer); + if (has_w) { + gw[i + r] = static_cast(gi * xi * normalizer); + } } } } diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 7fa7e8646..3f5216392 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -77,7 +77,7 @@ void RMSNorm::eval_gpu( group_dims = MTL::Size(threadgroup_size, 1, 1); } - uint32_t w_stride = w.strides()[0]; + uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( x.data_shared_ptr() == nullptr ? out : x, 0); @@ -101,20 +101,16 @@ void RMSNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
- std::vector copies; - auto check_input = [&copies, &s](const array& x) -> const array& { + auto check_input = [&d, &s](const array& x) -> array { if (x.flags().row_contiguous) { return x; } - // Make sure we 'll only ever allocate once. The point of that goes beyond - // the minor optimization. We need to ensure that there will be no - // reallocation such that the references won't change when we - // push_back(...). So tl;dr 3 possible copies x, g and gw_temp. - copies.reserve(3); - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + + return x_copy; }; const array& x = check_input(inputs[0]); const array& w = inputs[1]; @@ -122,6 +118,9 @@ void RMSNormVJP::eval_gpu( array& gx = outputs[0]; array& gw = outputs[1]; + // Check whether we had a weight + bool has_w = w.ndim() != 0; + // Allocate space for the outputs bool x_in_gx = false; bool g_in_gx = false; @@ -140,15 +139,18 @@ void RMSNormVJP::eval_gpu( // Allocate the gradient accumulator gw and a temporary to store the // gradients before they are accumulated. - array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}); + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; bool g_in_gw = false; - if (!g_in_gx && g.is_donatable()) { - gw_temp.move_shared_buffer(g); - g_in_gw = true; - } else { - gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + if (has_w) { + if (!g_in_gx && g.is_donatable()) { + gw_temp.move_shared_buffer(g); + g_in_gw = true; + } else { + gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + d.add_temporary(gw_temp, s.index); + } } - copies.push_back(gw_temp); gw.set_data(allocator::malloc_or_wait(gw.nbytes())); const int simd_size = 32; @@ -159,9 +161,15 @@ void RMSNormVJP::eval_gpu( op_name += "_looped"; } op_name += type_to_name(gx); + + std::string hash_name = op_name + ((has_w) ? "_w" : "_now"); + metal::MTLFCList func_consts = { + {&has_w, MTL::DataType::DataTypeBool, 20}, + }; + auto& compute_encoder = d.get_command_encoder(s.index); { - auto kernel = d.get_kernel(op_name); + auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -179,7 +187,7 @@ void RMSNormVJP::eval_gpu( group_dims = MTL::Size(threadgroup_size, 1, 1); } - uint32_t w_stride = w.strides()[0]; + uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(w, 1); @@ -192,12 +200,12 @@ void RMSNormVJP::eval_gpu( compute_encoder.dispatch_threads(grid_dims, group_dims); } - ReductionPlan plan( - ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); - strided_reduce_general_dispatch( - gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); - - d.add_temporaries(std::move(copies), s.index); + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + strided_reduce_general_dispatch( + gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); + } } void LayerNorm::eval_gpu( @@ -295,20 +303,16 @@ void LayerNormVJP::eval_gpu( // Ensure row contiguity. 
We could relax this step by checking that the array
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
-  std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) -> const array& {
+  auto check_input = [&d, &s](const array& x) -> array {
     if (x.flags().row_contiguous) {
       return x;
     }
-    // Make sure we 'll only ever allocate once. The point of that goes beyond
-    // the minor optimization. We need to ensure that there will be no
-    // reallocation such that the references won't change when we
-    // push_back(...). So tl;dr 3 possible copies x, g and gw_temp.
-    copies.reserve(3);
-    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
-    copy_gpu(x, copies.back(), CopyType::General, s);
-    return copies.back();
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    d.add_temporary(x_copy, s.index);
+
+    return x_copy;
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
@@ -318,6 +322,9 @@ void LayerNormVJP::eval_gpu(
   array& gw = outputs[1];
   array& gb = outputs[2];
 
+  // Check whether we had a weight
+  bool has_w = w.ndim() != 0;
+
   // Allocate space for the outputs
   bool x_in_gx = false;
   bool g_in_gx = false;
@@ -336,15 +343,18 @@ void LayerNormVJP::eval_gpu(
 
   // Allocate a temporary to store the gradients for w and allocate the output
   // gradient accumulators.
-  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
+  array gw_temp =
+      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
   bool g_in_gw = false;
-  if (!g_in_gx && g.is_donatable()) {
-    gw_temp.move_shared_buffer(g);
-    g_in_gw = true;
-  } else {
-    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+  if (has_w) {
+    if (!g_in_gx && g.is_donatable()) {
+      gw_temp.move_shared_buffer(g);
+      g_in_gw = true;
+    } else {
+      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+    }
+    d.add_temporary(gw_temp, s.index);
   }
-  copies.push_back(gw_temp);
   gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
   gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
 
@@ -372,8 +382,14 @@ void LayerNormVJP::eval_gpu(
     op_name += "_looped";
   }
   op_name += type_to_name(gx);
+
+  std::string hash_name = op_name + ((has_w) ? "_w" : "_now");
+  metal::MTLFCList func_consts = {
+      {&has_w, MTL::DataType::DataTypeBool, 20},
+  };
+
   {
-    auto kernel = d.get_kernel(op_name);
+    auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
 
     MTL::Size grid_dims, group_dims;
     if (axis_size <= looped_limit) {
@@ -404,14 +420,12 @@ void LayerNormVJP::eval_gpu(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 
-  if (gw.ndim() == 1 && gw.size() == axis_size) {
+  if (has_w) {
     ReductionPlan plan(
         ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
     strided_reduce_general_dispatch(
         gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
   }
-
-  d.add_temporaries(std::move(copies), s.index);
 }
 
 } // namespace mlx::core::fast
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 1967c018f..82e1f0569 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -54,30 +54,34 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
 
 array rms_norm(
     const array& x,
-    const array& weight,
+    const std::optional<array>& weight,
     float eps,
     StreamOrDevice s_ /* = {} */) {
+  bool has_weight = weight.has_value();
+
   if (x.ndim() == 0) {
     std::ostringstream msg;
     msg << "[rms_norm] Input must have at least 1 dimension but got input with "
            "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (weight.ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
-        << " dimensions.";
-    throw std::invalid_argument(msg.str());
-  }
-  if (weight.size() != x.shape(-1)) {
-    std::ostringstream msg;
-    msg << "[rms_norm] weight must have the same size as the last dimension of"
-           " x but has "
-        << weight.size() << " elements.";
-    throw std::invalid_argument(msg.str());
+  if (has_weight) {
+    if ((*weight).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[rms_norm] weight must have 1 dimension but has "
+          << (*weight).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*weight).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[rms_norm] weight must have the same size as the last dimension of"
+             " x but has "
+          << (*weight).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }
-  auto out_type = result_type(x, weight);
+  auto out_type = (weight.has_value()) ? result_type(x, (*weight)) : x.dtype();
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[rms_norm] Received unsupported type " << out_type << ".";
@@ -85,27 +89,36 @@ array rms_norm(
   }
 
   auto s = to_stream(s_);
-  auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
-    auto x = astype(inputs[0], float32, s);
-    x = multiply(
-        x,
-        rsqrt(
-            add(mean(square(x, s), -1, /* keepdims */ true, s),
-                array(eps, float32),
+  auto fallback =
+      [has_weight, eps, out_type, s](const std::vector<array>& inputs) {
+        auto x = astype(inputs[0], float32, s);
+        x = multiply(
+            x,
+            rsqrt(
+                add(mean(square(x, s), -1, /* keepdims */ true, s),
+                    array(eps, float32),
+                    s),
                 s),
-            s),
-        s);
-    x = astype(x, out_type, s);
-    return std::vector<array>{multiply(inputs[1], x, s)};
-  };
+            s);
+        x = astype(x, out_type, s);
+
+        if (has_weight) {
+          x = multiply(x, inputs[1], s);
+        }
+
+        return std::vector<array>{x};
+      };
+
+  auto passed_weight =
+      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
 
   if (s.device == Device::gpu) {
     return array(
         x.shape(),
         out_type,
         std::make_shared<RMSNorm>(s, fallback, eps),
-        {astype(x, out_type, s), astype(weight, out_type, s)});
+        {astype(x, out_type, s), passed_weight});
   }
-  return fallback({x, weight})[0];
+  return fallback({x, passed_weight})[0];
 }
 
 std::vector<array> RMSNorm::vjp(
@@ -141,8 +154,12 @@ std::vector<array> RMSNorm::vjp(
     // df/dw
     std::vector<int> axes(g.ndim() - 1);
     std::iota(axes.begin(), axes.end(), 0);
-    vjps.push_back(
-        sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
+    if (w.ndim() == 0) {
+      vjps.push_back(zeros_like(w, s));
+    } else {
+      vjps.push_back(sum(
+          multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));
+    }
 
     return vjps;
   };
@@ -177,28 +194,30 @@ array layer_norm(
     const std::optional<array>& bias,
     float eps,
     StreamOrDevice s_ /* = {} */) {
+  bool has_weight = weight.has_value();
+  bool has_bias = bias.has_value();
+
   if (x.ndim() == 0) {
     std::ostringstream msg;
     msg << "[layer_norm] Input must have at least 1 dimension but got input with "
            "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (weight.has_value() && (*weight).ndim() != 1) {
+  if (has_weight && (*weight).ndim() != 1) {
     std::ostringstream msg;
     msg << "[layer_norm] weight must have 1 dimension but has "
         << (*weight).ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (bias.has_value() && (*bias).ndim() != 1) {
+  if (has_bias && (*bias).ndim() != 1) {
     std::ostringstream msg;
     msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
         << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  auto out_type = (weight.has_value())
-      ? ((bias.has_value()) ? result_type(x, *weight, *bias)
-                            : result_type(x, *weight))
+  auto out_type = (has_weight)
+      ? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight))
       : x.dtype();
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
@@ -207,8 +226,6 @@ array layer_norm(
   }
 
   auto s = to_stream(s_);
-  bool has_weight = weight.has_value();
-  bool has_bias = bias.has_value();
   auto fallback = [has_weight, has_bias, eps, out_type, s](
                       const std::vector<array>& inputs) {
     auto x = astype(inputs[0], float32, s);
@@ -234,9 +251,9 @@ array layer_norm(
   };
 
   auto passed_weight =
-      astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
+      (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
   auto passed_bias =
-      astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
+      (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
 
   if (s.device == Device::gpu) {
     return array(
diff --git a/mlx/fast.h b/mlx/fast.h
index 9e6586cf6..fe93de85e 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -10,7 +10,7 @@ namespace mlx::core::fast {
 
 array rms_norm(
     const array& x,
-    const array& weight,
+    const std::optional<array>& weight,
     float eps,
     StreamOrDevice s = {});
 
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index d7ccc000b..fc2cbd41d 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -25,12 +25,12 @@ void init_fast(nb::module_& parent_module) {
       "rms_norm",
       &mx::fast::rms_norm,
       "x"_a,
-      "weight"_a,
+      "weight"_a.none(),
       "eps"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Root Mean Square normalization (RMS norm).
@@ -38,9 +38,9 @@ void init_fast(nb::module_& parent_module) { Args: x (array): Input array. - weight (array): A multiplicative weight to scale the result by. + weight (array, optional): A multiplicative weight to scale the result by. The ``weight`` should be one-dimensional with the same size - as the last axis of ``x``. + as the last axis of ``x``. If set to ``None`` then no scaling happens. eps (float): A small additive constant for numerical stability. Returns: diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 2aa8b067c..2c90a3755 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase): rx = rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = rms_norm(x, mx.ones_like(weight), eps) + rx_fast = mx.fast.rms_norm(x, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) for eps in epss: dtype, _, dims = defaults @@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase): rx = rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = rms_norm(x, mx.ones_like(weight), eps) + rx_fast = mx.fast.rms_norm(x, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) for dims in dimss: dtype, eps, _ = defaults @@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase): rx = rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = rms_norm(x, mx.ones_like(weight), eps) + rx_fast = mx.fast.rms_norm(x, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) # Test > 4096 dims, dtype, eps = 4099, mx.float32, 1e-5 @@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase): eps = 1e-5 f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum() f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum() + f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum() + f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum() x = mx.random.uniform(shape=(8, 100, D)) w = mx.random.uniform(shape=(D,)) @@ -341,6 +352,9 @@ class TestFast(mlx_tests.MLXTestCase): gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y) self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5) + gx1 = mx.grad(f3, argnums=(0,))(x, y) + gx2 = mx.grad(f4, argnums=(0,))(x, y) + self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) D = 8192 x = mx.random.uniform(shape=(2, 2, D)) @@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase): gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y) self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5) + gx1 = mx.grad(f3, argnums=(0,))(x, y) + gx2 = mx.grad(f4, argnums=(0,))(x, y) + self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5) def gf(f): def inner(x, w, y):
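# Illustrative usage sketch (editorial, not a hunk of the patch above): with
# this change, `weight` may be passed as None to mx.fast.rms_norm, and
# weight/bias as None to mx.fast.layer_norm, in which case only the
# normalization is applied and no scaling or shifting happens. Assumes an MLX
# build that includes this commit.
import mlx.core as mx

x = mx.random.uniform(shape=(4, 1024))
out_rms = mx.fast.rms_norm(x, None, 1e-5)  # same result as a weight of all ones
out_ln = mx.fast.layer_norm(x, None, None, 1e-5)  # no scale, no shift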