Fixes for phi4 mini (#1305)

This commit is contained in:
Awni Hannun
2025-02-26 16:21:54 -08:00
committed by GitHub
parent 0f240a4c7e
commit 00a7379070
2 changed files with 16 additions and 6 deletions

View File

@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
self.dim = dims
def __call__(self, x, offset: int = 0):
x[..., : self.dim] = self.scale * x[..., : self.dim]
return mx.fast.rope(
self.scale * x,
x.shape[-1],
x,
self.dim,
traditional=False,
base=None,
scale=1.0,