mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Add theta cache for RoPE and mask cache for ALiBi (#375)
This commit is contained in:
parent
c6d2878c1a
commit
48f6ca8c3a
@ -25,8 +25,15 @@ class RoPE(Module):
|
|||||||
base (float, optional): The base used to compute angular frequency for
|
base (float, optional): The base used to compute angular frequency for
|
||||||
each dimension in the positional encodings. Default: ``10000``.
|
each dimension in the positional encodings. Default: ``10000``.
|
||||||
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
|
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
|
||||||
|
_cos_sin_theta_value (tuple): Cached cosine and sine values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_cos_sin_theta_key = None
|
||||||
|
_cos_sin_theta_value = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dims: int,
|
dims: int,
|
||||||
@ -86,8 +93,9 @@ class RoPE(Module):
|
|||||||
|
|
||||||
return mx.reshape(rx, shape)
|
return mx.reshape(rx, shape)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def create_cos_sin_theta(
|
def create_cos_sin_theta(
|
||||||
|
cls,
|
||||||
N: int,
|
N: int,
|
||||||
D: int,
|
D: int,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
@ -95,11 +103,14 @@ class RoPE(Module):
|
|||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
dtype=mx.float32,
|
dtype=mx.float32,
|
||||||
):
|
):
|
||||||
D = D // 2
|
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
|
||||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
D = D // 2
|
||||||
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
|
||||||
return mx.cos(theta), mx.sin(theta)
|
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||||
|
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
|
||||||
|
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
|
||||||
|
return cls._cos_sin_theta_value
|
||||||
|
|
||||||
|
|
||||||
class SinusoidalPositionalEncoding(Module):
|
class SinusoidalPositionalEncoding(Module):
|
||||||
@ -163,22 +174,42 @@ class SinusoidalPositionalEncoding(Module):
|
|||||||
|
|
||||||
|
|
||||||
class ALiBi(Module):
|
class ALiBi(Module):
|
||||||
@staticmethod
|
_alibi_mask_key = None
|
||||||
|
_alibi_mask = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def create_alibi_matrix(
|
def create_alibi_matrix(
|
||||||
|
cls,
|
||||||
q_sequence_length: int,
|
q_sequence_length: int,
|
||||||
k_sequence_length: int,
|
k_sequence_length: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
offset: int,
|
offset: int,
|
||||||
dtype=mx.float32,
|
dtype=mx.float32,
|
||||||
):
|
):
|
||||||
x1 = mx.arange(offset, q_sequence_length)
|
if (
|
||||||
x2 = mx.arange(0, k_sequence_length)
|
q_sequence_length,
|
||||||
distance_matrix = -mx.abs(
|
k_sequence_length,
|
||||||
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
|
num_heads,
|
||||||
)
|
offset,
|
||||||
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
|
dtype,
|
||||||
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
|
) != cls._alibi_mask_key:
|
||||||
return alibi_mask
|
x1 = mx.arange(offset, q_sequence_length)
|
||||||
|
x2 = mx.arange(0, k_sequence_length)
|
||||||
|
distance_matrix = -mx.abs(
|
||||||
|
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
|
||||||
|
)
|
||||||
|
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
|
||||||
|
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
|
||||||
|
cls._alibi_mask_key = (
|
||||||
|
q_sequence_length,
|
||||||
|
k_sequence_length,
|
||||||
|
num_heads,
|
||||||
|
offset,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
|
cls._alibi_mask = alibi_mask
|
||||||
|
|
||||||
|
return cls._alibi_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_alibi_slope(num_heads):
|
def create_alibi_slope(num_heads):
|
||||||
@ -196,4 +227,4 @@ class ALiBi(Module):
|
|||||||
)
|
)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
alibi_mask = alibi_mask + mask
|
alibi_mask = alibi_mask + mask
|
||||||
return attention_scores + alibi_mask
|
return attention_scores + alibi_mask
|
Loading…
Reference in New Issue
Block a user