mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-31 20:14:36 +08:00
Fix copy donation and add partial rope (#881)
This commit is contained in:

committed by
GitHub

parent
8e5a5a1ccd
commit
6ee1112f30
@@ -6,21 +6,21 @@ from time_utils import time_fn
|
||||
|
||||
|
||||
def time_rope():
|
||||
rope = nn.RoPE(4096)
|
||||
rope = nn.RoPE(64)
|
||||
|
||||
# vec
|
||||
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
|
||||
mx.eval(x)
|
||||
|
||||
def rope_vec(x):
|
||||
for _ in range(32):
|
||||
x = rope(x)
|
||||
x = rope(x, offset=100)
|
||||
return x
|
||||
|
||||
time_fn(rope_vec, x)
|
||||
|
||||
# matrix
|
||||
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
|
||||
mx.eval(x)
|
||||
|
||||
def rope_mat(x):
|
||||
|
Reference in New Issue
Block a user