Rope theta to support Code Llama (#121)

* rope theta for llama model

* llama chat/code

* nit
Awni Hannun 2023-12-15 19:51:51 -08:00 committed by GitHub
parent db134d976d
commit 08e862336a
3 changed files with 32 additions and 4 deletions
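For context (not part of the commit message): rope_theta is the base of the rotary position embedding (RoPE) frequencies. Dimension pair i of a head is rotated at position p by the angle p * base**(-2i / head_dim), so a larger base slows the rotations and stretches the usable context. Code Llama reportedly ships with a base of 1,000,000 instead of the 10,000 used by Llama 1/2, which is why the value has to be read from the checkpoint config rather than hard-coded. A standalone sketch of that frequency computation, separate from the MLX code in the diff below (names and values are illustrative):

import math

def rope_angle(pair_index: int, position: int, head_dim: int, base: float) -> float:
    # RoPE rotates dimension pair `pair_index` at `position` by p * base ** (-2i / head_dim).
    return position * base ** (-2 * pair_index / head_dim)

# A larger base (Code Llama reportedly uses 1e6 vs. 1e4 for Llama 1/2) rotates the
# high pairs far more slowly, which is what supports the longer context window.
print(rope_angle(pair_index=63, position=4096, head_dim=128, base=10_000.0))
print(rope_angle(pair_index=63, position=4096, head_dim=128, base=1_000_000.0))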

@@ -3,7 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters.
+ranging from 7B to 70B parameters. This example also supports Llama Chat and
+Code Llama.
 
 ### Setup
 

@@ -23,6 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
+    rope_theta: float
 
 
 class RMSNorm(nn.Module):
@@ -39,6 +40,27 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
+        super().__init__(dims, traditional)
+        self.base = base
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
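A note on the hunk above: the subclass presumably exists because the nn.RoPE module in the MLX release of that time did not take a base argument, so __call__ is re-implemented to rebuild the cos/sin tables with the configured base while reusing the parent's helpers. A hypothetical usage sketch, assuming the RoPE class above and an MLX version contemporary with this commit (shapes and the base value are illustrative):

import mlx.core as mx

# Hypothetical usage of the RoPE subclass defined in the hunk above.
rope = RoPE(dims=128, traditional=True, base=1e6)
x = mx.zeros((1, 32, 16, 128))   # (batch, n_heads, seq_len, head_dim)
y = rope(x)                      # rotate positions 0..15
y_next = rope(x, offset=16)      # angles continue from position 16, as with a KV cache
print(y.shape, y_next.shape)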
@@ -55,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
         self,
@@ -315,7 +337,9 @@ def load_model(model_path):
         config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
     if config.get("vocab_size", -1) < 0:
         config["vocab_size"] = weights["output.weight"].shape[-1]
-    unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
+    if "rope_theta" not in config:
+        config["rope_theta"] = 10000
+    unused = ["multiple_of", "ffn_dim_multiplier"]
     for k in unused:
         if k in config:
             config.pop(k)
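The defaulting above keeps older checkpoints working: a Llama 1/2 params.json has no rope_theta entry, so the original base of 10000 is filled in, while a Code Llama config that does carry the key passes straight through to ModelArgs instead of being dropped as unused. A small standalone sketch of that behaviour (config contents are illustrative, not copied from a real checkpoint):

def default_rope_theta(config: dict) -> dict:
    # Mirrors the branch added to load_model in the hunk above.
    if "rope_theta" not in config:
        config["rope_theta"] = 10000
    return config

llama2_like = {"dim": 4096, "n_heads": 32}                                  # no rope_theta key
code_llama_like = {"dim": 4096, "n_heads": 32, "rope_theta": 1_000_000.0}   # assumed value

print(default_rope_theta(dict(llama2_like))["rope_theta"])       # 10000
print(default_rope_theta(dict(code_llama_like))["rope_theta"])   # 1000000.0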

@@ -206,7 +206,10 @@ if __name__ == "__main__":
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
-            eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
+            eos_index = next(
+                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+                None,
+            )
             if eos_index is not None:
                 tokens = tokens[:eos_index]