Mirror of https://github.com/ml-explore/mlx-examples.git
Rope theta to support Code Llama (#121)

* rope theta for llama model
* llama chat/code
* nit
@@ -3,7 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters.
+ranging from 7B to 70B parameters. This example also supports Llama Chat and
+Code Llama.
 
 ### Setup
 
@@ -23,6 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
+    rope_theta: float
 
 
 class RMSNorm(nn.Module):
@@ -39,6 +40,27 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
+        super().__init__(dims, traditional)
+        self.base = base
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
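The subclass only makes the frequency base configurable; with the default base of 10000 it behaves like nn.RoPE. As a rough standalone sketch (not part of the commit) of what the base controls: each pair of dimensions rotates with frequency base^(-2i/dims), so a larger base, such as the 1e6 that Code Llama checkpoints typically ship with, rotates more slowly per position and supports longer contexts.

    # Illustrative sketch only -- how `base` sets the per-dimension RoPE frequencies.
    def rope_frequencies(dims: int, base: float) -> list:
        # theta_i = base^(-2i / dims) for each dimension pair i
        return [base ** (-2.0 * i / dims) for i in range(dims // 2)]

    llama_freqs = rope_frequencies(128, 10000.0)   # Llama 1/2 default base
    code_llama_freqs = rope_frequencies(128, 1e6)  # typical Code Llama rope_theta
    # A larger base gives smaller frequencies, i.e. slower rotation per position.
    print(llama_freqs[-1], code_llama_freqs[-1])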
@@ -55,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
         self,
@@ -315,7 +337,9 @@ def load_model(model_path):
             config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
         if config.get("vocab_size", -1) < 0:
             config["vocab_size"] = weights["output.weight"].shape[-1]
-        unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
+        if "rope_theta" not in config:
+            config["rope_theta"] = 10000
+        unused = ["multiple_of", "ffn_dim_multiplier"]
         for k in unused:
             if k in config:
                 config.pop(k)
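A small standalone sketch of the fallback above (the config dicts here are hypothetical, not taken from real checkpoints): older Llama 1/2 configs have no rope_theta entry and get the original 10000, while configs that specify a larger value, as Code Llama does, keep it.

    # Hypothetical config dicts; mirrors the default-then-keep logic in load_model.
    llama2_config = {"dim": 4096, "n_heads": 32}                              # no rope_theta key
    code_llama_config = {"dim": 4096, "n_heads": 32, "rope_theta": 1000000.0}

    for config in (llama2_config, code_llama_config):
        if "rope_theta" not in config:
            config["rope_theta"] = 10000

    print(llama2_config["rope_theta"], code_llama_config["rope_theta"])  # 10000 1000000.0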
Awni Hannun