Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 21:01:32 +08:00)
fix
@@ -59,19 +59,17 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = 1.0
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
                 head_dim,
-                traditional=False,
                 base=args.rope_theta,
-                scale=rope_scale,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
                 short_factor=args.rope_scaling["short_factor"],
                 long_factor=args.rope_scaling["long_factor"],
             )
         else:
+            rope_scale = 1.0
             if args.rope_scaling and args.rope_scaling["type"] == "linear":
                 assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
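Note: after this change the su/longrope branch no longer forwards `traditional` or `scale`, and `rope_scale` only matters for the linear-scaling branch. A minimal sketch of how the updated constructor would be called, assuming a hypothetical Phi-3-style `rope_scaling` config (the concrete numbers and the import path are illustrative assumptions, not taken from this commit):

# assumption: the class is importable from the su_rope module in mlx_lm
from mlx_lm.models.su_rope import SuScaledRotaryEmbedding

head_dim = 96  # illustrative head dimension
rope_scaling = {
    "type": "longrope",
    "short_factor": [1.0] * (head_dim // 2),  # one factor per rotated dim pair
    "long_factor": [4.0] * (head_dim // 2),
}
rope = SuScaledRotaryEmbedding(
    head_dim,
    base=10000.0,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    short_factor=rope_scaling["short_factor"],
    long_factor=rope_scaling["long_factor"],
)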
@@ -11,9 +11,7 @@ class SuScaledRotaryEmbedding(nn.Module):
     def __init__(
         self,
         dims: int,
-        traditional: bool = False,
         base: float = 10000.0,
-        scale: float = 1.0,
         max_position_embeddings: int = 131072,
         original_max_position_embeddings: int = 4096,
         short_factor: Union[List[float], float] = 1.0,
@@ -24,10 +22,7 @@ class SuScaledRotaryEmbedding(nn.Module):
 
         Args:
             dims (int): The feature dimensions to be rotated.
-            traditional (bool, optional): Unused. Default: ``False``.
             base (int, optional): Base for the exponential scaling.
-            scale (float, optional): The scale used to scale the positions.
-              Default: ``1.0``.
             max_position_embeddings (int, optional): The maximum sequence
               length that this model was trained with. This is used to determine
               the size of the original RoPE embeddings when using long scaling.
@@ -44,14 +39,9 @@ class SuScaledRotaryEmbedding(nn.Module):
               ``original_max_position_embeddings``. Default: ``1.0``.
         """
         super().__init__()
-        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * base ** (
-            mx.arange(0, dims, 2, dtype=mx.float32) / dims
-        )
-        self._long_freqs = (
-            scale
-            * mx.array(long_factor, dtype=mx.float32)
-            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
-        )
+        freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs
+        self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs
         self.original_max_position_embeddings = original_max_position_embeddings
         self.scale = math.sqrt(
             1
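Note: the refactor above computes the base frequencies once and applies the short/long factors elementwise, and it drops the extra `scale` multiplication from `_long_freqs`. A small sketch of the resulting shapes, assuming tiny illustrative dimensions and factors (not taken from any model config):

import mlx.core as mx

dims, base = 8, 10000.0
# one frequency per rotated dimension pair, shape (dims // 2,)
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
short_factor = mx.array([1.0, 1.0, 1.1, 1.2], dtype=mx.float32)  # illustrative
long_factor = mx.array([2.0, 2.5, 3.0, 3.5], dtype=mx.float32)   # illustrative
short_freqs = short_factor * freqs  # used while positions fit the original window
long_freqs = long_factor * freqs    # used once the context exceeds it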
@@ -66,11 +56,11 @@ class SuScaledRotaryEmbedding(nn.Module):
             else self._short_freqs
         )
         return mx.fast.rope(
-            x,
+            self.scale * x,
             x.shape[-1],
             traditional=False,
             base=None,
-            scale=self.scale,
+            scale=1.0,
             offset=offset,
             freqs=freqs,
         )
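Note: this hunk is the core of the fix. The removed docstring entry above described `scale` as "used to scale the positions", which is what the `scale` argument of `mx.fast.rope` does, while the su-scaled factor `self.scale` is meant to rescale the queries/keys themselves. The corrected call therefore multiplies the input and pins `scale` to `1.0`. A minimal sketch of the before/after calls, with illustrative shapes and an assumed scale value:

import mlx.core as mx

scale = 1.19  # illustrative attention factor, as computed in __init__
x = mx.random.normal((1, 4, 8, 64))  # (batch, heads, seq_len, head_dim)
freqs = 10000.0 ** (mx.arange(0, 64, 2, dtype=mx.float32) / 64)

# before: the positions were scaled, which changes the rotation angles
y_old = mx.fast.rope(x, x.shape[-1], traditional=False, base=None,
                     scale=scale, offset=0, freqs=freqs)
# after: the activations are scaled and the positions are left untouched
y_new = mx.fast.rope(scale * x, x.shape[-1], traditional=False, base=None,
                     scale=1.0, offset=0, freqs=freqs)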