mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-11 06:04:36 +08:00
chore: clean up the rope scalling factor param in create cos sin theta
This commit is contained in:
@@ -52,11 +52,11 @@ class LinearScalingRoPE(nn.RoPE):
|
|||||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||||
N = x.shape[1] + offset
|
N = x.shape[1] + offset
|
||||||
costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
|
costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
|
||||||
self.rope_scaling_factor,
|
|
||||||
N,
|
N,
|
||||||
self.dims,
|
self.dims,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
base=self.base,
|
base=self.base,
|
||||||
|
rope_scaling_factor=self.rope_scaling_factor,
|
||||||
dtype=x.dtype,
|
dtype=x.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,11 +66,11 @@ class LinearScalingRoPE(nn.RoPE):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_cos_sin_theta(
|
def create_cos_sin_theta(
|
||||||
rope_scaling_factor: float,
|
|
||||||
N: int,
|
N: int,
|
||||||
D: int,
|
D: int,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
|
rope_scaling_factor: float = 1.0,
|
||||||
dtype=mx.float32,
|
dtype=mx.float32,
|
||||||
):
|
):
|
||||||
D = D // 2
|
D = D // 2
|
||||||
|
Reference in New Issue
Block a user