Awni Hannun
2024-08-20 17:24:05 -07:00
parent 15975697d2
commit 0a52a9d55a
2 changed files with 6 additions and 18 deletions


@@ -59,19 +59,17 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = 1.0
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
                 head_dim,
-                traditional=False,
                 base=args.rope_theta,
-                scale=rope_scale,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
                 short_factor=args.rope_scaling["short_factor"],
                 long_factor=args.rope_scaling["long_factor"],
             )
         else:
+            rope_scale = 1.0
             if args.rope_scaling and args.rope_scaling["type"] == "linear":
                 assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
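
Which branch runs above is driven by the model's rope_scaling config entry. A minimal sketch of the two shapes it can take, with placeholder values (not taken from any particular checkpoint):

# Hypothetical rope_scaling entries, for illustration only.

# Selects the SuScaledRotaryEmbedding branch: "longrope"/"su" scaling with
# per-frequency factor lists (typically head_dim // 2 entries, matching the
# frequency vector they multiply).
rope_scaling_su = {
    "type": "longrope",
    "short_factor": [1.0] * 48,  # placeholder values
    "long_factor": [2.0] * 48,   # placeholder values
}

# Selects the else branch: linear scaling, where rope_scale becomes 1 / factor.
rope_scaling_linear = {
    "type": "linear",
    "factor": 4.0,
}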


@@ -11,9 +11,7 @@ class SuScaledRotaryEmbedding(nn.Module):
     def __init__(
         self,
         dims: int,
-        traditional: bool = False,
         base: float = 10000.0,
-        scale: float = 1.0,
         max_position_embeddings: int = 131072,
         original_max_position_embeddings: int = 4096,
         short_factor: Union[List[float], float] = 1.0,
@@ -24,10 +22,7 @@ class SuScaledRotaryEmbedding(nn.Module):
         Args:
             dims (int): The feature dimensions to be rotated.
-            traditional (bool, optional): Unused. Default: ``False``.
             base (int, optional): Base for the exponential scaling.
-            scale (float, optional): The scale used to scale the positions.
-                Default: ``1.0``.
             max_position_embeddings (int, optional): The maximum sequence
                 length that this model was trained with. This is used to determine
                 the size of the original RoPE embeddings when using long scaling.
@@ -44,14 +39,9 @@ class SuScaledRotaryEmbedding(nn.Module):
                 ``original_max_position_embeddings``. Default: ``1.0``.
         """
         super().__init__()
-        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * base ** (
-            mx.arange(0, dims, 2, dtype=mx.float32) / dims
-        )
-        self._long_freqs = (
-            scale
-            * mx.array(long_factor, dtype=mx.float32)
-            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
-        )
+        freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs
+        self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs
         self.original_max_position_embeddings = original_max_position_embeddings
         self.scale = math.sqrt(
             1
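
The self.scale expression is cut off by the diff context here. Assuming it continues with the usual LongRoPE attention-scaling formula sqrt(1 + log(max_position_embeddings / original_max_position_embeddings) / log(original_max_position_embeddings)) — an assumption, since the rest of the expression is not shown — the signature defaults above give roughly:

import math

# Assumed continuation of the truncated expression, evaluated with the
# defaults from the __init__ signature above; a sketch, not the file itself.
max_position_embeddings = 131072
original_max_position_embeddings = 4096

scale = math.sqrt(
    1
    + math.log(max_position_embeddings / original_max_position_embeddings)
    / math.log(original_max_position_embeddings)
)
print(scale)  # ~1.19, i.e. sqrt(1 + ln(32) / ln(4096)) = sqrt(1 + 5/12)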
@@ -66,11 +56,11 @@ class SuScaledRotaryEmbedding(nn.Module):
             else self._short_freqs
         )
         return mx.fast.rope(
-            x,
+            self.scale * x,
             x.shape[-1],
             traditional=False,
             base=None,
-            scale=self.scale,
+            scale=1.0,
             offset=offset,
             freqs=freqs,
         )
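
A minimal usage sketch with the trimmed constructor, assuming the __call__(x, offset) interface shown in the last hunk; head_dim and the factor lists are placeholders, not values from a real checkpoint:

import mlx.core as mx

# SuScaledRotaryEmbedding is the class defined in the file above; its module
# path is not shown in this diff, so no import for it is given here.
head_dim = 96  # hypothetical head dimension

# One factor per rotated frequency, i.e. head_dim // 2 entries.
short_factor = [1.0] * (head_dim // 2)  # placeholder values
long_factor = [2.0] * (head_dim // 2)   # placeholder values

rope = SuScaledRotaryEmbedding(
    head_dim,
    base=10000.0,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    short_factor=short_factor,
    long_factor=long_factor,
)

# Queries or keys shaped (batch, n_heads, seq_len, head_dim). After this
# change the module folds the attention scale into the input (self.scale * x)
# and passes scale=1.0 to mx.fast.rope instead of scaling positions.
x = mx.zeros((1, 32, 8, head_dim))
y = rope(x, offset=0)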