# Copyright © 2023 Apple Inc.
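
# Benchmark of a single Llama-style transformer layer written with JAX / Flax
# (flax.linen), part of the ml-explore/mlx repository
# (https://github.com/ml-explore/mlx). It times one decoding step (a single
# new token) against a pre-filled key/value cache.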
import math
import time

import jax
import jax.numpy as jnp
from flax import linen as nn

class RoPE(nn.Module):
    dims: int
    traditional: bool = False

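    # Rotate the first `dims` channels of x, treating the first and second
    # halves of that slice as the two rotation components; channels beyond
    # `dims` are passed through unchanged.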
    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = jnp.concatenate([rx1, rx2], axis=-1)

        return rx

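    # The "traditional" RoPE formulation: rotate interleaved (even, odd)
    # channel pairs. Partial application (dims < last axis) is not supported
    # for this variant.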
    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

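    # Build the rotation angle table theta[p, i] = p * base**(-2 * i / D) for
    # positions p in [offset, N) and i in [0, D // 2); returns cos(theta) and
    # sin(theta), each of shape (N - offset, D // 2).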
    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        dtype=jnp.float32,
    ):
        D = D // 2
        positions = jnp.arange(offset, N, dtype=dtype)
        freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
        theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        return costheta, sintheta

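    # Flatten all leading axes into one batch axis, compute the angles for the
    # absolute positions [offset, N), apply the selected rotation, and restore
    # the original shape.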
    @nn.compact
    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = x.reshape((-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.reshape(shape)

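# A minimal usage sketch (not part of the benchmark; the shapes below are
# illustrative assumptions). RoPE has no parameters and preserves the input
# shape, so for per-head inputs of shape (batch, heads, sequence, head_dim)
# one would expect something like:
#
#     rope = RoPE(dims=128, traditional=True)
#     y = rope.apply({}, jnp.zeros((1, 32, 16, 128)))
#     # y.shape == (1, 32, 16, 128)
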
class LlamaAttention(nn.Module):
    dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        num_heads = self.num_heads
        dims = self.dims

        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

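    # Project the inputs, split them into heads with layout
    # (batch, heads, sequence, head_dim), apply RoPE (offset by the cache
    # length when decoding), run scaled dot-product attention, and merge the
    # heads back. Returns the output projection together with the updated
    # (keys, values) cache.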
    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))

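        # With a cache, rotate the new queries/keys with an offset equal to the
        # number of cached positions so they receive the correct absolute
        # positions, then append the new keys/values along the sequence axis.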
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = jnp.concatenate([key_cache, keys], axis=2)
            values = jnp.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = jax.nn.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))

        return self.out_proj(values_hat), (keys, values)

class LlamaEncoderLayer(nn.Module):
    dims: int
    mlp_dims: int
    num_heads: int
    dtype: jnp.dtype

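    # A pre-norm transformer block: RMSNorm -> attention -> residual, followed
    # by RMSNorm -> gated (SwiGLU-style) MLP -> residual.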
    def setup(self):
        dims = self.dims
        mlp_dims = self.mlp_dims
        num_heads = self.num_heads

        self.attention = LlamaAttention(dims, num_heads, self.dtype)

        self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
        self.norm2 = nn.RMSNorm(param_dtype=self.dtype)

        self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

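        # SwiGLU-style feed-forward: silu(linear1(y)) gates linear2(y), then
        # linear3 projects back to the model dimension.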
        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = jax.nn.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache

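# Average latency of one forward call in milliseconds: 5 untimed warmup
# iterations (which also trigger JIT compilation), then 5 timed iterations.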
def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        jax.block_until_ready((y, c))

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        jax.block_until_ready((y, c))

    end = time.time()
    return (end - start) * 1000 / 5

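# Benchmark configuration: 32 heads, model dimension 4096, MLP dimension
# 43 * 256 = 11008, and a pre-filled cache of 1000 positions, decoding a
# single token in float16. These are the per-layer dimensions of a 7B LLaMA
# model, which also has 32 layers (hence the final x32 estimate below).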
if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    dtype = jnp.float16

    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

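    # A single new token of activations plus a pre-filled key/value cache,
    # each cache entry of shape (batch, heads, cached positions, head_dim).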
    x = jax.random.normal(k1, (1, 1, D), dtype)
    cache = [
        jax.random.normal(k2, [1, H, C, D // H], dtype),
        jax.random.normal(k3, [1, H, C, D // H], dtype),
    ]

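    # Initialize the parameters once, then benchmark a jitted apply of the
    # layer with those parameters closed over.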
    layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
    params = layer.init(k4, x, mask=None, cache=cache)["params"]

    @jax.jit
    def model_fn(x, mask, cache):
        return layer.apply({"params": params}, x, mask=mask, cache=cache)

    T = measure(model_fn, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")