chore: clean up the rope_scaling_factor param in create_cos_sin_theta

This commit is contained in:
Anchen
2023-12-23 17:17:14 +11:00
committed by Awni Hannun
parent 784149d699
commit bd63a3e5ee

View File

@@ -52,11 +52,11 @@ class LinearScalingRoPE(nn.RoPE):
x = mx.reshape(x, (-1, shape[-2], shape[-1])) x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset N = x.shape[1] + offset
costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta( costheta, sintheta = LinearScalingRoPE.create_cos_sin_theta(
self.rope_scaling_factor,
N, N,
self.dims, self.dims,
offset=offset, offset=offset,
base=self.base, base=self.base,
rope_scaling_factor=self.rope_scaling_factor,
dtype=x.dtype, dtype=x.dtype,
) )
@@ -66,11 +66,11 @@ class LinearScalingRoPE(nn.RoPE):
@staticmethod @staticmethod
def create_cos_sin_theta( def create_cos_sin_theta(
rope_scaling_factor: float,
N: int, N: int,
D: int, D: int,
offset: int = 0, offset: int = 0,
base: float = 10000, base: float = 10000,
rope_scaling_factor: float = 1.0,
dtype=mx.float32, dtype=mx.float32,
): ):
D = D // 2 D = D // 2