# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def rms_norm(x, w, eps):
    """Reference RMS normalization over the last axis of ``x``.

    The computation is carried out in float32 and the normalized result is
    cast back to the dtype ``x`` had on entry before the optional weight is
    applied.

    Args:
        x: Input array.
        w: Optional scale multiplied into the normalized output; pass
            ``None`` to skip scaling.
        eps: Constant added to the mean of squares before the reciprocal
            square root is taken.

    Returns:
        The normalized (and, if ``w`` is given, scaled) array.
    """
    ot = x.dtype
    x = x.astype(mx.float32)
    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    y = (x * n).astype(ot)
    if w is not None:
        y = y * w
    return y


def time_rms_norm():
    """Time gradients of the reference ``rms_norm`` against ``mx.fast.rms_norm``.

    Two scenarios are benchmarked, each both as-is and wrapped in
    ``mx.compile``:

    1. With a weight: gradients are taken with respect to ``x`` and ``w``.
    2. Without a weight (``w=None``): gradients are taken with respect to
       ``x`` only.

    Every timed call chains ``num_iters`` gradient evaluations, feeding each
    result back in as the next input.
    """
    eps = 1e-5
    shape = (8, 1024, 4096)
    num_iters = 32

    # --- Scenario 1: weighted RMS norm, gradients w.r.t. x and w. ---
    def ref_loss(x, w, y):
        return (rms_norm(x, w, eps) * y).sum()

    def fast_loss(x, w, y):
        return (mx.fast.rms_norm(x, w, eps) * y).sum()

    g1 = mx.grad(ref_loss, argnums=(0, 1))
    g2 = mx.grad(fast_loss, argnums=(0, 1))

    x = mx.random.uniform(shape=shape).astype(mx.float16)
    w = mx.random.uniform(shape=(shape[-1],)).astype(mx.float16)
    y = mx.random.uniform(shape=shape).astype(mx.float16)
    # Evaluate the inputs before any timed call is made.
    mx.eval(x, w, y)

    def rms_norm_loop(g, x, w):
        gx, gw = x, w
        for _ in range(num_iters):
            gx, gw = g(gx, gw, y)
        return gx, gw

    time_fn(rms_norm_loop, g1, x, w)
    time_fn(rms_norm_loop, g2, x, w)
    time_fn(rms_norm_loop, mx.compile(g1), x, w)
    time_fn(rms_norm_loop, mx.compile(g2), x, w)

    # --- Scenario 2: no weight, gradient w.r.t. x only. ---
    def ref_loss_no_weight(x, y):
        return (rms_norm(x, None, eps) * y).sum()

    def fast_loss_no_weight(x, y):
        return (mx.fast.rms_norm(x, None, eps) * y).sum()

    # argnums is a 1-tuple, as in the weighted case; the loop below feeds
    # whatever the gradient function returns straight back in as ``x``.
    g1 = mx.grad(ref_loss_no_weight, argnums=(0,))
    g2 = mx.grad(fast_loss_no_weight, argnums=(0,))

    # No weight array is created here: this scenario never uses one.
    x = mx.random.uniform(shape=shape).astype(mx.float16)
    y = mx.random.uniform(shape=shape).astype(mx.float16)
    mx.eval(x, y)

    # Deliberately reuses the name ``rms_norm_loop`` for the weight-free
    # variant (the callable's name may be what ``time_fn`` reports — TODO
    # confirm against time_utils).
    def rms_norm_loop(g, x):
        gx = x
        for _ in range(num_iters):
            gx = g(gx, y)
        return gx

    time_fn(rms_norm_loop, g1, x)
    time_fn(rms_norm_loop, g2, x)
    time_fn(rms_norm_loop, mx.compile(g1), x)
    time_fn(rms_norm_loop, mx.compile(g2), x)


if __name__ == "__main__":
    time_rms_norm()