diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 71d35f370..a872d5b99 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -44,7 +44,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp diff --git a/mlx/backend/common/rope.cpp b/mlx/backend/common/rope.cpp deleted file mode 100644 index 15b5de7e5..000000000 --- a/mlx/backend/common/rope.cpp +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include "mlx/fast_primitives.h" - -namespace mlx::core::fast { - -void RoPE::eval_cpu( - const std::vector& inputs, - std::vector& outputs) { - throw std::runtime_error("NYI"); -} - -} // namespace mlx::core::fast diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp index 87ce748c8..777163b9f 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/common/softmax.cpp @@ -67,11 +67,15 @@ void Softmax::eval(const std::vector& inputs, array& out) { } }; array in = check_input(std::move(inputs[0])); - out.set_data( - allocator::malloc_or_wait(in.data_size() * in.itemsize()), - in.data_size(), - in.strides(), - in.flags()); + if (in.is_donatable()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + allocator::malloc_or_wait(in.data_size() * in.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } switch (in.dtype()) { case bool_: diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index b8d3b26fe..0b561728d 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -33,6 +33,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index b265babbe..1a112731f 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -23,6 +23,7 @@ set( "gemv" "quantized" "random" + "rms_norm" "rope" "scan" "scaled_dot_product_attention" diff --git a/mlx/backend/metal/kernels/defines.h b/mlx/backend/metal/kernels/defines.h index bdd1419f2..9e62b7c32 100644 --- a/mlx/backend/metal/kernels/defines.h +++ b/mlx/backend/metal/kernels/defines.h @@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; static MTL_CONST constexpr int REDUCE_N_READS = 16; static MTL_CONST constexpr int SOFTMAX_N_READS = 4; static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096; +static MTL_CONST constexpr int RMS_N_READS = 4; +static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal new file mode 100644 index 000000000..0382b6335 --- /dev/null +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -0,0 +1,194 @@ +// Copyright © 2024 Apple Inc. 
+ +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +template +[[kernel]] void rms_single_row( + const device T* x, + const device T* w, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + threadgroup float* local_inv_mean [[threadgroup(0)]], + threadgroup float* local_sums [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float acc = 0; + x += gid * axis_size + lid * N_READS; + w += w_stride * lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i]; + acc += xi * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + float xi = x[i]; + acc += xi * xi; + } + } + } + acc = simd_sum(acc); + // Initialize shared memory + if (simd_group_id == 0) { + local_sums[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sums[simd_group_id] = acc; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + acc = simd_sum(local_sums[simd_lane_id]); + if (simd_lane_id == 0) { + local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the outputs + out += gid * axis_size + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = w[w_stride * i] * static_cast(x[i] * 
local_inv_mean[0]); + } + } + } +} + +template +[[kernel]] void rms_looped( + const device T* x, + const device T* w, + device T* out, + constant float& eps, + constant uint& axis_size, + constant uint& w_stride, + threadgroup float* local_inv_mean [[threadgroup(0)]], + threadgroup float* local_sums [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + float acc = 0; + x += gid * axis_size + lid * N_READS; + w += w_stride * lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float xi = x[i + r]; + acc += xi * xi; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float xi = x[i + r]; + acc += xi * xi; + } + } + } + } + acc = simd_sum(acc); + // Initialize shared memory + if (simd_group_id == 0) { + local_sums[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write simd accumulations into shared memory + if (simd_lane_id == 0) { + local_sums[simd_group_id] = acc; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Accumulate over simd groups + if (simd_group_id == 0) { + acc = simd_sum(local_sums[simd_lane_id]); + if (simd_lane_id == 0) { + local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the outputs + out += gid * axis_size + lid * N_READS; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[r + i] = w[w_stride * (i + r)] * + static_cast(x[r + i] * local_inv_mean[0]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + out[r + 
i] = w[w_stride * (i + r)] * + static_cast(x[r + i] * local_inv_mean[0]); + } + } + } + } +} + +// clang-format off +#define instantiate_rms_single_row(name, itype) \ + template [[host_name("rms" #name)]] [[kernel]] void \ + rms_single_row( \ + const device itype* x, \ + const device itype* w, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + threadgroup float* local_inv_mean [[threadgroup(0)]], \ + threadgroup float* local_sums [[threadgroup(1)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_rms_looped(name, itype) \ + template [[host_name("rms_looped" #name)]] [[kernel]] void \ + rms_looped( \ + const device itype* x, \ + const device itype* w, \ + device itype* out, \ + constant float& eps, \ + constant uint& axis_size, \ + constant uint& w_stride, \ + threadgroup float* local_inv_mean [[threadgroup(0)]], \ + threadgroup float* local_sums [[threadgroup(1)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_rms(name, itype) \ + instantiate_rms_single_row(name, itype) \ + instantiate_rms_looped(name, itype) + +instantiate_rms(float32, float) +instantiate_rms(float16, half) +instantiate_rms(bfloat16, bfloat16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal index 0877de075..2fdcaaa56 100644 --- a/mlx/backend/metal/kernels/softmax.metal +++ b/mlx/backend/metal/kernels/softmax.metal @@ -1,6 +1,5 @@ // Copyright © 2023 Apple Inc. 
-#include #include #include @@ -224,5 +223,6 @@ template instantiate_softmax_single_row(name, itype) \ instantiate_softmax_looped(name, itype) -instantiate_softmax(float32, float) instantiate_softmax(float16, half) - instantiate_softmax(bfloat16, bfloat16_t) +instantiate_softmax(float32, float) +instantiate_softmax(float16, half) +instantiate_softmax(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/rms_norm.cpp b/mlx/backend/metal/rms_norm.cpp new file mode 100644 index 000000000..a3a783f31 --- /dev/null +++ b/mlx/backend/metal/rms_norm.cpp @@ -0,0 +1,98 @@ +// Copyright © 2024 Apple Inc. +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/fast_primitives.h" + +namespace mlx::core::fast { + +void RMSNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& d = metal::device(s.device); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous + std::vector copies; + auto check_input = [&copies, &s](const array& x) { + bool no_copy = x.strides()[x.ndim() - 1] == 1; + if (x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + return x; + } else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + copies.push_back(x_copy); + return x_copy; + } + }; + const array& x = check_input(inputs[0]); + const array& w = inputs[1]; + + if (x.is_donatable()) { + out.move_shared_buffer(x); + } else { + out.set_data( + allocator::malloc_or_wait(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + + auto axis_size = static_cast(x.shape().back()); + int n_rows = x.data_size() / axis_size; + + const int simd_size = 32; + const int n_reads = RMS_N_READS; + const int looped_limit = RMS_LOOPED_LIMIT; + std::string op_name = "rms"; + if 
(axis_size > looped_limit) { + op_name += "_looped"; + } + op_name += type_to_name(out); + auto compute_encoder = d.get_command_encoder(s.index); + { + auto kernel = d.get_kernel(op_name); + + MTL::Size grid_dims, group_dims; + if (axis_size <= looped_limit) { + size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; + size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; + size_t threadgroup_size = simd_size * simds_needed; + assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + size_t n_threads = n_rows * threadgroup_size; + grid_dims = MTL::Size(n_threads, 1, 1); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } else { + size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); + size_t n_threads = n_rows * threadgroup_size; + grid_dims = MTL::Size(n_threads, 1, 1); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } + + uint32_t w_stride = w.strides()[0]; + compute_encoder->setComputePipelineState(kernel); + set_array_buffer( + compute_encoder, x.data_shared_ptr() == nullptr ? 
out : x, 0); + set_array_buffer(compute_encoder, w, 1); + set_array_buffer(compute_encoder, out, 2); + compute_encoder->setBytes(&eps_, sizeof(float), 3); + compute_encoder->setBytes(&axis_size, sizeof(int), 4); + compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5); + compute_encoder->setThreadgroupMemoryLength( + 16 * 8, 0); // minimum of 16 bytes + compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index be25bc032..3a1405e53 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -37,11 +37,15 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { } }; const array& in = check_input(inputs[0]); - out.set_data( - allocator::malloc_or_wait(in.data_size() * in.itemsize()), - in.data_size(), - in.strides(), - in.flags()); + if (in.is_donatable()) { + out.move_shared_buffer(in); + } else { + out.set_data( + allocator::malloc_or_wait(in.data_size() * in.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } int axis_size = in.shape().back(); int n_rows = in.data_size() / axis_size; @@ -75,6 +79,8 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { } compute_encoder->setComputePipelineState(kernel); + set_array_buffer( + compute_encoder, in.data_shared_ptr() == nullptr ? 
out : in, 0); + // Input 0 is already bound above (to `out` when `in` was donated); do not re-bind `in`. - set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); compute_encoder->setBytes(&axis_size, sizeof(int), 2); diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index cd5bef5c4..6bbde8a50 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -102,6 +102,7 @@ NO_GPU(Transpose) NO_GPU(Inverse) namespace fast { +NO_GPU_MULTI(RMSNorm) NO_GPU_MULTI(RoPE) NO_GPU(ScaledDotProductAttention) } // namespace fast diff --git a/mlx/fast.cpp b/mlx/fast.cpp index cfdae139a..cd1f4ecce 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -46,6 +46,60 @@ std::pair, std::vector> Custom::vmap( return {outputs, out_axes}; } +array rms_norm( + const array& x, + const array& weight, + float eps, + StreamOrDevice s_ /* = {} */) { + if (x.ndim() == 0) { + std::ostringstream msg; + msg << "[rms_norm] Input must have at least 1 dimension but got input with " + "0 dimensions."; + throw std::invalid_argument(msg.str()); + } + if (weight.ndim() != 1) { + std::ostringstream msg; + msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + auto out_type = result_type({x, weight}); + if (!is_floating_point(out_type) || is_complex(out_type)) { + std::ostringstream msg; + msg << "[rms_norm] Received unsupported type " << out_type << "."; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + auto fallback = [eps, out_type, s](const std::vector& inputs) { + auto x = astype(inputs[0], float32, s); + x = multiply( + x, + rsqrt( + add(mean(square(x, s), -1, /* keepdims */ true, s), + array(eps, float32), + s), + s), + s); + x = astype(x, out_type, s); + return std::vector{multiply(inputs[1], x, s)}; + }; + if (s.device == Device::gpu) { + // Both inputs are cast to out_type below, so the output must be declared as out_type too. + return array( + x.shape(), + out_type, + std::make_unique(s, fallback, eps), + {astype(x, out_type, s), astype(weight, out_type, s)}); + } + return 
fallback({x, weight})[0]; +} + +bool RMSNorm::is_equivalent(const Primitive& other) const { + const RMSNorm& a_other = static_cast(other); + return eps_ == a_other.eps_; +} + array rope( const array& x, int dims, diff --git a/mlx/fast.h b/mlx/fast.h index 74b04886c..7e08533ca 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -8,6 +8,12 @@ namespace mlx::core::fast { +array rms_norm( + const array& x, + const array& weight, + float eps, + StreamOrDevice s = {}); + array rope( const array& x, int dims, @@ -15,7 +21,7 @@ array rope( float base, float scale, int offset, - StreamOrDevice s /* = {} */); + StreamOrDevice s = {}); /** Computes: O = softmax(Q @ K.T) @ V **/ array scaled_dot_product_attention( diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index b581b09d9..ea2b56d05 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -1,3 +1,5 @@ +// Copyright © 2024 Apple Inc. + #include "mlx/primitives.h" namespace mlx::core::fast { @@ -31,6 +33,29 @@ class Custom : public Primitive { std::function(std::vector)> fallback_; }; +class RMSNorm : public Custom { + public: + RMSNorm( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, fallback), eps_(eps){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + }; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(RMSNorm) + bool is_equivalent(const Primitive& other) const override; + + private: + // The fallback callable is forwarded to Custom; only eps is kept here. + float eps_; +}; + class RoPE : public Custom { public: RoPE( @@ -49,7 +74,9 @@ class RoPE : public Custom { offset_(offset){}; void eval_cpu(const std::vector& inputs, std::vector& outputs) - override; + override { + throw std::runtime_error("NYI"); + }; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 
42107d658..60d033b0d 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -117,6 +117,8 @@ class RMSNorm(Module): where :math:`\gamma` is a learned per feature dimension parameter initialized at 1. + Note the accumulation for the mean is done in 32-bit precision. + [1]: https://arxiv.org/abs/1910.07467 Args: @@ -133,18 +135,7 @@ class RMSNorm(Module): return f"{self.weight.shape[0]}, eps={self.eps}" def __call__(self, x): - # S is 1/sqrt(N) where N is the size of the features of x and is used - # to compute a numerically more stable RMS of x by multiplying with S - # first and summing. - # - # This way we prefer underflow over overflow which is controlled with - # the parameter epsilon anyway. - S = 1 / x.shape[-1] ** 0.5 - - n = (x * S).square().sum(axis=-1, keepdims=True) - n = mx.rsqrt(n + self.eps) - - return self.weight * x * n + return mx.fast.rms_norm(x, self.weight, self.eps) class GroupNorm(Module): diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 20f0f7033..5bc274ca1 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) { auto m = parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); + m.def( + "rms_norm", + [](const array& x, + const array& weight, + float eps, + const StreamOrDevice& s /* = {} */) { + return fast::rms_norm(x, weight, eps, s); + }, + "x"_a, + "weight"_a, + "eps"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Root Mean Square normalization (RMS norm). + + The normalization is with respect to the last axis of the input ``x``. + + Args: + x (array): Input array. + weight (array): A multiplicative weight to scale the result by. + The ``weight`` should be one-dimensional with the same size + as the last axis of ``x``. 
+ eps (float): A small additive constant for numerical stability. + + Returns: + array: The output array. + )pbdoc"); + m.def( "rope", [](const array& a, diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 1cb4ddcca..d3285e993 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -115,6 +115,57 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + def test_rms_norm(self): + def rms_norm(x, weight, eps): + x = x.astype(mx.float32) + x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) + return weight * x.astype(weight.dtype) + + # Per dtype absolute tolerance + tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} + + dtypes = [mx.float32, mx.float16, mx.bfloat16] + epss = [1e-3, 1e-5] + dimss = [31, 32, 33] + defaults = (mx.float32, 1e-5, 32) + + for dtype in dtypes: + _, eps, dims = defaults + x = mx.random.uniform( + shape=( + 2, + dims, + ) + ).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = rms_norm(x, weight, eps) + rx_fast = mx.fast.rms_norm(x, weight, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + for eps in epss: + dtype, _, dims = defaults + x = mx.random.uniform(shape=(2, dims)).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = rms_norm(x, weight, eps) + rx_fast = mx.fast.rms_norm(x, weight, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + for dims in dimss: + dtype, eps, _ = defaults + x = mx.random.uniform(shape=(2, dims)).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = rms_norm(x, weight, eps) + rx_fast = mx.fast.rms_norm(x, weight, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + # Test > 4096 + dims, dtype, eps = 4099, mx.float32, 1e-5 + x = mx.random.uniform(shape=(dims,)).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = rms_norm(x, weight, 
eps) + rx_fast = mx.fast.rms_norm(x, weight, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6) + def test_fast_transforms(self): x = mx.random.uniform(shape=(2, 2, 8))