Add theta cache for RoPE and mask cache for ALiBi (#375)

This commit is contained in:
Anchen 2024-01-07 00:22:58 -08:00 committed by GitHub
parent c6d2878c1a
commit 48f6ca8c3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,8 +25,15 @@ class RoPE(Module):
base (float, optional): The base used to compute angular frequency for base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``. each dimension in the positional encodings. Default: ``10000``.
scale (float, optional): The scale used to scale the positions. Default: ``1.0``. scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
Attributes:
_cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
_cos_sin_theta_value (tuple): Cached cosine and sine values.
""" """
_cos_sin_theta_key = None
_cos_sin_theta_value = None
def __init__( def __init__(
self, self,
dims: int, dims: int,
@ -86,8 +93,9 @@ class RoPE(Module):
return mx.reshape(rx, shape) return mx.reshape(rx, shape)
# Side-by-side diff of RoPE.create_cos_sin_theta: the left column is the old
# @staticmethod, the right column the new @classmethod that keeps the most
# recent (cos, sin) pair in the class attributes _cos_sin_theta_key and
# _cos_sin_theta_value. The `base` parameter used in the body is declared on a
# line that falls between the two hunks and is not shown here.
@staticmethod @classmethod
def create_cos_sin_theta( def create_cos_sin_theta(
cls,
N: int, N: int,
D: int, D: int,
offset: int = 0, offset: int = 0,
@ -95,11 +103,14 @@ class RoPE(Module):
scale: float = 1.0, scale: float = 1.0,
dtype=mx.float32, dtype=mx.float32,
): ):
# NOTE(review): the new column tests the key with the caller's D, then rebinds
# D = D // 2 before storing the key. The stored key therefore holds the halved
# value, so for any positive D it never equals the key built on the next call
# and the cos/sin tables are recomputed every time. Building the key tuple
# once, before D is halved, and storing that same tuple would let the cache hit.
D = D // 2 if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
positions = mx.arange(offset, N, dtype=dtype) * scale D = D // 2
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) positions = mx.arange(offset, N, dtype=dtype) * scale
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
return mx.cos(theta), mx.sin(theta) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
return cls._cos_sin_theta_value
class SinusoidalPositionalEncoding(Module): class SinusoidalPositionalEncoding(Module):
@ -163,22 +174,42 @@ class SinusoidalPositionalEncoding(Module):
class ALiBi(Module): class ALiBi(Module):
@staticmethod _alibi_mask_key = None
_alibi_mask = None
# Side-by-side diff of ALiBi.create_alibi_matrix: the left column is the old
# @staticmethod that builds the mask on every call, the right column the new
# @classmethod that returns the mask stored on the class when the arguments
# match the stored key.
@classmethod
def create_alibi_matrix( def create_alibi_matrix(
cls,
q_sequence_length: int, q_sequence_length: int,
k_sequence_length: int, k_sequence_length: int,
num_heads: int, num_heads: int,
offset: int, offset: int,
dtype=mx.float32, dtype=mx.float32,
): ):
# The cache key is the full argument tuple. On a mismatch the mask is rebuilt:
# the negated absolute difference between mx.arange(offset, q_sequence_length)
# and mx.arange(0, k_sequence_length), given two leading axes, multiplied by
# the result of create_alibi_slope(num_heads) and cast to dtype.
x1 = mx.arange(offset, q_sequence_length) if (
x2 = mx.arange(0, k_sequence_length) q_sequence_length,
distance_matrix = -mx.abs( k_sequence_length,
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1)) num_heads,
) offset,
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads) dtype,
alibi_mask = (distance_matrix * alibi_slope).astype(dtype) ) != cls._alibi_mask_key:
return alibi_mask x1 = mx.arange(offset, q_sequence_length)
x2 = mx.arange(0, k_sequence_length)
distance_matrix = -mx.abs(
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
# Only one entry is kept: each miss overwrites the key and mask stored on cls.
cls._alibi_mask_key = (
q_sequence_length,
k_sequence_length,
num_heads,
offset,
dtype,
)
cls._alibi_mask = alibi_mask
return cls._alibi_mask
@staticmethod @staticmethod
def create_alibi_slope(num_heads): def create_alibi_slope(num_heads):
@ -196,4 +227,4 @@ class ALiBi(Module):
) )
if mask is not None: if mask is not None:
alibi_mask = alibi_mask + mask alibi_mask = alibi_mask + mask
return attention_scores + alibi_mask return attention_scores + alibi_mask