mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from time_utils import time_fn
|
|
|
|
|
|
def rms_norm(x, w, eps):
    """Reference (non-fused) RMS normalization along the last axis.

    The mean square is computed in float32. The normalized values are cast
    back to ``x``'s original dtype before the optional weight is applied.

    Args:
        x: input array, normalized along its last axis.
        w: optional scale multiplied into the normalized result, or ``None``
            to skip the scaling.
        eps: small constant added to the mean square before ``rsqrt``.

    Returns:
        ``x * rsqrt(mean(x**2, axis=-1) + eps)`` cast to ``x.dtype``,
        multiplied by ``w`` when ``w`` is not ``None``.
    """
    out_dtype = x.dtype
    x32 = x.astype(mx.float32)
    inv_rms = mx.rsqrt(x32.square().mean(-1, keepdims=True) + eps)
    normed = (x32 * inv_rms).astype(out_dtype)
    return normed if w is None else normed * w
|
|
|
|
|
|
def time_rms_norm():
    """Benchmark RMSNorm gradients: reference ``rms_norm`` vs ``mx.fast.rms_norm``.

    Two scenarios are timed, each eagerly and under ``mx.compile``:

    1. with a weight ``w``: gradients w.r.t. ``x`` and ``w``;
    2. without a weight: gradient w.r.t. ``x`` only.

    Every timed call chains 32 gradient evaluations, feeding each step's
    gradients back in as the next step's inputs. A label line is printed
    before each ``time_fn`` call so the eight timings can be told apart.
    """
    shape = (8, 1024, 4096)
    eps = 1e-5
    num_steps = 32

    def run_variants(title, loop, g_ref, g_fast, *args):
        # Time the reference and fused gradients, each eagerly and compiled.
        variants = (
            ("reference", g_ref),
            ("fast", g_fast),
            ("reference, compiled", mx.compile(g_ref)),
            ("fast, compiled", mx.compile(g_fast)),
        )
        for label, g in variants:
            print(f"{title} [{label}]")
            time_fn(loop, g, *args)

    # ---- With weight: gradients w.r.t. x and w ----
    f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
    f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1))
    g2 = mx.grad(f2, argnums=(0, 1))

    x = mx.random.uniform(shape=shape).astype(mx.float16)
    w = mx.random.uniform(shape=(shape[-1],)).astype(mx.float16)
    y = mx.random.uniform(shape=shape).astype(mx.float16)
    mx.eval(x, w, y)

    def loop_with_weight(g, x, w, y):
        # y is an explicit argument rather than a closure variable, so
        # rebinding y later in this function cannot affect this loop.
        gx, gw = x, w
        for _ in range(num_steps):
            gx, gw = g(gx, gw, y)
        return gx, gw

    run_variants(
        "rms_norm grad with weight", loop_with_weight, g1, g2, x, w, y
    )

    # ---- Without weight: gradient w.r.t. x only ----
    f1 = lambda x, y: (rms_norm(x, None, eps) * y).sum()
    f2 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()
    g1 = mx.grad(f1, argnums=(0,))
    g2 = mx.grad(f2, argnums=(0,))

    x = mx.random.uniform(shape=shape).astype(mx.float16)
    y = mx.random.uniform(shape=shape).astype(mx.float16)
    mx.eval(x, y)

    def loop_without_weight(g, x, y):
        # NOTE(review): relies on mx.grad with a one-element argnums
        # returning the bare gradient array, which is fed back in as x.
        gx = x
        for _ in range(num_steps):
            gx = g(gx, y)
        return gx

    run_variants(
        "rms_norm grad without weight", loop_without_weight, g1, g2, x, y
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    time_rms_norm()
|