mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Add Phi-3.5-MoE (#946)
* add phimoe * add phimoe to tunner * add switch_mlp * fix SuScaled args * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -16,6 +16,8 @@ class SuScaledRotaryEmbedding(nn.Module):
|
||||
original_max_position_embeddings: int = 4096,
|
||||
short_factor: Union[List[float], float] = 1.0,
|
||||
long_factor: Union[List[float], float] = 1.0,
|
||||
short_mscale: float = None,
|
||||
long_mscale: float = None,
|
||||
):
|
||||
"""
|
||||
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
|
||||
@@ -37,12 +39,14 @@ class SuScaledRotaryEmbedding(nn.Module):
|
||||
long_factor (float or list[float], optional): List of scaling
|
||||
factors for sequences of length greater than
|
||||
``original_max_position_embeddings``. Default: ``1.0``.
|
||||
short_mscale (float, optional): Scale the input prior to embedding.
|
||||
long_mscale (float, optional): Scale the input prior to embedding.
|
||||
"""
|
||||
super().__init__()
|
||||
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
||||
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.scale = math.sqrt(
|
||||
self.scale = long_mscale or math.sqrt(
|
||||
1
|
||||
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
||||
/ math.log(original_max_position_embeddings)
|
||||
|
||||
Reference in New Issue
Block a user