Fix copy donation and add partial rope (#881)

This commit is contained in:
Angelos Katharopoulos
2024-03-22 17:28:26 -07:00
committed by GitHub
parent 8e5a5a1ccd
commit 6ee1112f30
6 changed files with 42 additions and 20 deletions

View File

@@ -6,21 +6,21 @@ from time_utils import time_fn
def time_rope():
rope = nn.RoPE(4096)
rope = nn.RoPE(64)
# vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x)
x = rope(x, offset=100)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):