chore: clean up the rope scalling factor param in create cos sin theta

This commit is contained in:
Anchen 2023-12-23 17:17:14 +11:00 committed by Awni Hannun
parent 784149d699
commit bd63a3e5ee

View File

@ -52,11 +52,11 @@ class LinearScalingRoPE(nn.RoPE):
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
self.rope_scaling_factor,
N,
self.dims,
offset=offset,
base=self.base,
rope_scaling_factor=self.rope_scaling_factor,
dtype=x.dtype,
)
@ -66,11 +66,11 @@ class LinearScalingRoPE(nn.RoPE):
@staticmethod
def create_cos_sin_theta(
rope_scaling_factor: float,
N: int,
D: int,
offset: int = 0,
base: float = 10000,
rope_scaling_factor: float = 1.0,
dtype=mx.float32,
):
D = D // 2