This commit is contained in:
Goekdeniz-Guelmez 2025-01-18 20:59:02 +01:00
parent 342fa4af66
commit 62842d218d

View File

@ -32,6 +32,7 @@ def rotate_half(x: mx.array) -> mx.array:
x2 = x[..., 1::2]
return mx.concatenate([-x2, x1], axis=-1)
def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, position_ids=None, unsqueeze_dim=1) -> Tuple[mx.array, mx.array]:
"""
Applies Rotary Position Embedding to the query and key tensors.