2024-02-14 14:04:25 -08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
from time_utils import time_fn
|
|
|
|
|
|
|
|
|
|
|
|
def time_rope():
    """Benchmark ``nn.RoPE(64)`` on two float16 inputs.

    Each benchmark applies the same RoPE module 32 times in a row. It first
    runs on a ``(1, 32, 1, 128)`` input ("vec" case, with ``offset=100``) and
    then on a ``(1, 32, 1024, 128)`` input ("matrix" case, default offset).
    Timing and reporting are delegated to ``time_fn`` from the local
    ``time_utils`` module.
    """
    # dims=64 while inputs have 128 features on the last axis; per the
    # nn.RoPE docs, features beyond ``dims`` are passed through unrotated.
    rotary = nn.RoPE(64)

    # NOTE(review): these inner names are kept deliberately; time_fn may use
    # the callable's __name__ as its report label - confirm in time_utils.
    def rope_vec(y):
        for _ in range(32):
            y = rotary(y, offset=100)
        return y

    def rope_mat(y):
        for _ in range(32):
            y = rotary(y)
        return y

    # Benchmarks run in this order: "vec" (size-1 third axis), then
    # "matrix" (1024 along that axis). Presumably the axes are
    # (batch, heads, seq, head_dim) - verify.
    cases = (
        (rope_vec, (1, 32, 1, 128)),
        (rope_mat, (1, 32, 1024, 128)),
    )
    for bench, shape in cases:
        inp = mx.random.uniform(shape=shape).astype(mx.float16)
        # MLX builds arrays lazily; force the input to be computed before
        # timing so its generation is not measured by time_fn.
        mx.eval(inp)
        time_fn(bench, inp)


if __name__ == "__main__":
    time_rope()
|