diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py index 88bb69de..23f45bc0 100644 --- a/llms/mlx_lm/models/helium.py +++ b/llms/mlx_lm/models/helium.py @@ -32,6 +32,7 @@ def rotate_half(x: mx.array) -> mx.array: x2 = x[..., 1::2] return mx.concatenate([-x2, x1], axis=-1) + def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, position_ids=None, unsqueeze_dim=1) -> Tuple[mx.array, mx.array]: """ Applies Rotary Position Embedding to the query and key tensors.