# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    """Time a chain of 32 ``nn.RoPE(64)`` applications on two inputs.

    Two cases are benchmarked with ``time_fn``, in this order:

    * ``rope_vec`` -- a float16 input of shape (1, 32, 1, 128), rotated with
      ``offset=100`` on every application.
    * ``rope_mat`` -- a float16 input of shape (1, 32, 1024, 128), rotated
      with the layer's default offset.

    In each case, every RoPE output is fed back in as the next input, so
    the timed computation is 32 dependent RoPE calls.
    """
    rope = nn.RoPE(64)
    n_applications = 32

    def _random_fp16(shape):
        # Draw uniform samples, cast to float16, and force evaluation up
        # front so input creation is not folded into the timed region.
        sample = mx.random.uniform(shape=shape).astype(mx.float16)
        mx.eval(sample)
        return sample

    # vec -- NOTE(review): axes presumably (batch, heads, seq, head_dim); confirm.
    vec_input = _random_fp16((1, 32, 1, 128))

    def rope_vec(x):
        out = x
        for _step in range(n_applications):
            out = rope(out, offset=100)
        return out

    time_fn(rope_vec, vec_input)

    # matrix -- same layout as above but with 1024 along the third axis.
    mat_input = _random_fp16((1, 32, 1024, 128))

    def rope_mat(x):
        out = x
        for _step in range(n_applications):
            out = rope(out)
        return out

    time_fn(rope_mat, mat_input)


if __name__ == "__main__":
    time_rope()