From 48f6ca8c3a8666052b72c8c68c705bce1880df3e Mon Sep 17 00:00:00 2001 From: Anchen Date: Sun, 7 Jan 2024 00:22:58 -0800 Subject: [PATCH] Add theta cache for Rope and mask cache for ALiBi (#375) --- python/mlx/nn/layers/positional_encoding.py | 63 +++++++++++++++------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index 032fd0f92..7ed4a19d7 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -25,8 +25,15 @@ class RoPE(Module): base (float, optional): The base used to compute angular frequency for each dimension in the positional encodings. Default: ``10000``. scale (float, optional): The scale used to scale the positions. Default: ``1.0``. + + Attributes: + _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values. + _cos_sin_theta_value (tuple): Cached cosine and sine values. """ + _cos_sin_theta_key = None + _cos_sin_theta_value = None + def __init__( self, dims: int, @@ -86,8 +93,9 @@ class RoPE(Module): return mx.reshape(rx, shape) - @staticmethod + @classmethod def create_cos_sin_theta( + cls, N: int, D: int, offset: int = 0, @@ -95,11 +103,14 @@ class RoPE(Module): scale: float = 1.0, dtype=mx.float32, ): - D = D // 2 - positions = mx.arange(offset, N, dtype=dtype) * scale - freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) - theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) - return mx.cos(theta), mx.sin(theta) + if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: + D = D // 2 + positions = mx.arange(offset, N, dtype=dtype) * scale + freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) + theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) + cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) + cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) + return cls._cos_sin_theta_value class SinusoidalPositionalEncoding(Module): @@ -163,22 +174,42 @@ class SinusoidalPositionalEncoding(Module): class ALiBi(Module): - @staticmethod + _alibi_mask_key = None + _alibi_mask = None + + @classmethod def create_alibi_matrix( + cls, q_sequence_length: int, k_sequence_length: int, num_heads: int, offset: int, dtype=mx.float32, ): - x1 = mx.arange(offset, q_sequence_length) - x2 = mx.arange(0, k_sequence_length) - distance_matrix = -mx.abs( - mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1)) - ) - alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads) - alibi_mask = (distance_matrix * alibi_slope).astype(dtype) - return alibi_mask + if ( + q_sequence_length, + k_sequence_length, + num_heads, + offset, + dtype, + ) != cls._alibi_mask_key: + x1 = mx.arange(offset, q_sequence_length) + x2 = mx.arange(0, k_sequence_length) + distance_matrix = -mx.abs( + mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1)) + ) + alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads) + alibi_mask = (distance_matrix * alibi_slope).astype(dtype) + cls._alibi_mask_key = ( + q_sequence_length, + k_sequence_length, + num_heads, + offset, + dtype, + ) + cls._alibi_mask = alibi_mask + + return cls._alibi_mask @staticmethod def create_alibi_slope(num_heads): @@ -196,4 +227,4 @@ class ALiBi(Module): ) if mask is not None: alibi_mask = alibi_mask + mask - return attention_scores + alibi_mask + return attention_scores + alibi_mask \ No newline at end of file