From 48f6ca8c3a8666052b72c8c68c705bce1880df3e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Sun, 7 Jan 2024 00:22:58 -0800
Subject: [PATCH] Add theta cache for Rope and mask cache for ALiBi (#375)
---
python/mlx/nn/layers/positional_encoding.py | 63 +++++++++++++++------
1 file changed, 47 insertions(+), 16 deletions(-)
diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py
index 032fd0f92..7ed4a19d7 100644
--- a/python/mlx/nn/layers/positional_encoding.py
+++ b/python/mlx/nn/layers/positional_encoding.py
@@ -25,8 +25,15 @@ class RoPE(Module):
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``.
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
+
+ Attributes:
+ _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
+ _cos_sin_theta_value (tuple): Cached cosine and sine values.
"""
+ _cos_sin_theta_key = None
+ _cos_sin_theta_value = None
+
def __init__(
self,
dims: int,
@@ -86,8 +93,9 @@ class RoPE(Module):
return mx.reshape(rx, shape)
- @staticmethod
+ @classmethod
def create_cos_sin_theta(
+ cls,
N: int,
D: int,
offset: int = 0,
@@ -95,11 +103,14 @@ class RoPE(Module):
scale: float = 1.0,
dtype=mx.float32,
):
- D = D // 2
- positions = mx.arange(offset, N, dtype=dtype) * scale
- freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
- theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
- return mx.cos(theta), mx.sin(theta)
+ if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
+ D = D // 2
+ positions = mx.arange(offset, N, dtype=dtype) * scale
+ freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
+ theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+ cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
+ cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
+ return cls._cos_sin_theta_value
class SinusoidalPositionalEncoding(Module):
@@ -163,22 +174,42 @@ class SinusoidalPositionalEncoding(Module):
class ALiBi(Module):
- @staticmethod
+ _alibi_mask_key = None
+ _alibi_mask = None
+
+ @classmethod
def create_alibi_matrix(
+ cls,
q_sequence_length: int,
k_sequence_length: int,
num_heads: int,
offset: int,
dtype=mx.float32,
):
- x1 = mx.arange(offset, q_sequence_length)
- x2 = mx.arange(0, k_sequence_length)
- distance_matrix = -mx.abs(
- mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
- )
- alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
- alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
- return alibi_mask
+ if (
+ q_sequence_length,
+ k_sequence_length,
+ num_heads,
+ offset,
+ dtype,
+ ) != cls._alibi_mask_key:
+ x1 = mx.arange(offset, q_sequence_length)
+ x2 = mx.arange(0, k_sequence_length)
+ distance_matrix = -mx.abs(
+ mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
+ )
+ alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
+ alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
+ cls._alibi_mask_key = (
+ q_sequence_length,
+ k_sequence_length,
+ num_heads,
+ offset,
+ dtype,
+ )
+ cls._alibi_mask = alibi_mask
+
+ return cls._alibi_mask
@staticmethod
def create_alibi_slope(num_heads):
@@ -196,4 +227,4 @@ class ALiBi(Module):
)
if mask is not None:
alibi_mask = alibi_mask + mask
- return attention_scores + alibi_mask
+ return attention_scores + alibi_mask
\ No newline at end of file