From 08e862336ade809bc37d1035f94b359e7d1a5152 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 15 Dec 2023 19:51:51 -0800
Subject: [PATCH] Rope theta to support Code Llama (#121)

* rope theta for llama model

* llama chat/code

* nit
---
 llama/README.md |  3 ++-
 llama/llama.py  | 28 ++++++++++++++++++++++++++--
 phi2/phi2.py    |  5 ++++-
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/llama/README.md b/llama/README.md
index 3ad882de..220d1b16 100644
--- a/llama/README.md
+++ b/llama/README.md
@@ -3,7 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters.
+ranging from 7B to 70B parameters. This example also supports Llama Chat and
+Code Llama.
 
 ### Setup
 
diff --git a/llama/llama.py b/llama/llama.py
index 73eb39c5..5f169de4 100644
--- a/llama/llama.py
+++ b/llama/llama.py
@@ -23,6 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
+    rope_theta: float
 
 
 class RMSNorm(nn.Module):
@@ -39,6 +40,27 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
+        super().__init__(dims, traditional)
+        self.base = base
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -55,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
         self,
@@ -315,7 +337,9 @@ def load_model(model_path):
         config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
     if config.get("vocab_size", -1) < 0:
         config["vocab_size"] = weights["output.weight"].shape[-1]
-    unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
+    if "rope_theta" not in config:
+        config["rope_theta"] = 10000
+    unused = ["multiple_of", "ffn_dim_multiplier"]
     for k in unused:
         if k in config:
             config.pop(k)
diff --git a/phi2/phi2.py b/phi2/phi2.py
index 78885eb8..555ee232 100644
--- a/phi2/phi2.py
+++ b/phi2/phi2.py
@@ -206,7 +206,10 @@ if __name__ == "__main__":
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
 
-    eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
+    eos_index = next(
+        (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+        None,
+    )
 
     if eos_index is not None:
         tokens = tokens[:eos_index]
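
A minimal sketch (not part of the patch) of what the new rope_theta parameter controls, assuming the standard RoPE frequency formula base ** (-2 * i / dims) that the RoPE class above parameterizes; the 1e6 value shown for Code Llama is taken from its published configurations and is an assumption here, not something this patch sets:

    # How rope_theta (the RoPE base) scales the per-dimension rotation frequencies.
    # A larger base gives slower rotations in the higher dimensions, which is the
    # change Code Llama relies on for longer contexts.
    import numpy as np

    def rope_freqs(dims: int, base: float) -> np.ndarray:
        # One frequency per pair of dimensions, as in rotary position embeddings.
        return base ** (-np.arange(0, dims, 2) / dims)

    head_dim = 128  # head dimension of the 7B Llama models
    print(rope_freqs(head_dim, base=10000.0)[:4])    # Llama 1/2 default (rope_theta = 10000)
    print(rope_freqs(head_dim, base=1000000.0)[:4])  # assumed Code Llama value (rope_theta = 1e6)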