Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Rope theta to support Code Llama (#121)
* rope theta for llama model
* llama chat/code
* nit
This commit is contained in:
parent db134d976d
commit 08e862336a
@@ -3,7 +3,8 @@
 An example of generating text with Llama (1 or 2) using MLX.
 
 Llama is a set of open source language models from Meta AI Research[^1][^2]
-ranging from 7B to 70B parameters.
+ranging from 7B to 70B parameters. This example also supports Llama Chat and
+Code Llama.
 
 ### Setup
 
@@ -23,6 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
+    rope_theta: float
 
 
 class RMSNorm(nn.Module):
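For reference, `rope_theta` is the base b of the RoPE angle schedule. With head dimension d, the rotation angle applied to dimension pair j at position p is:

\[
\theta_j = b^{-2j/d}, \qquad \phi_{p,j} = p\,\theta_j, \qquad j = 0, \dots, d/2 - 1
\]

Raising b above the classic 10000 (Code Llama's published configs use 1e6; that value is background context, not part of this diff) lowers every rotation frequency, which is what lets the model address much longer contexts.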
@@ -39,6 +40,27 @@ class RMSNorm(nn.Module):
         return self.weight * output
 
 
+class RoPE(nn.RoPE):
+    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
+        super().__init__(dims, traditional)
+        self.base = base
+
+    def __call__(self, x, offset: int = 0):
+        shape = x.shape
+        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
+        N = x.shape[1] + offset
+        costheta, sintheta = RoPE.create_cos_sin_theta(
+            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
+        )
+
+        rope = (
+            self._compute_traditional_rope if self.traditional else self._compute_rope
+        )
+        rx = rope(costheta, sintheta, x)
+
+        return mx.reshape(rx, shape)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
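The subclass exists only to thread a configurable `base` through to `create_cos_sin_theta`; the reshape and the traditional/default compute paths mirror `nn.RoPE`. A minimal usage sketch (the head dimension, tensor shapes, and 1e6 base below are illustrative assumptions, not values from this diff):

```python
import mlx.core as mx

# Illustrative sketch: dims=128 and base=1e6 are assumptions for the example.
rope = RoPE(128, traditional=True, base=1e6)

x = mx.random.normal((1, 32, 16, 128))  # (batch, n_heads, seq_len, head_dim)
y = rope(x)                             # rotates positions 0..15

# With a KV cache, the cached length is passed as the offset so that a new
# token continues the position sequence instead of restarting at 0:
x_step = mx.random.normal((1, 32, 1, 128))
y_step = rope(x_step, offset=16)        # position 16
```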
@@ -55,7 +77,7 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = nn.RoPE(args.head_dim, traditional=True)
+        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
         self,
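`Attention.__call__` (unchanged by this hunk) is where the rope actually runs. A rough sketch of the usual pattern in this example, with variable names assumed rather than taken from the diff:

```python
# Hypothetical sketch of the rope application inside Attention.__call__;
# queries/keys/values and the (key_cache, value_cache) tuple are assumed names.
if cache is not None:
    key_cache, value_cache = cache
    # Offset by the cached sequence length so new tokens get fresh positions.
    queries = self.rope(queries, offset=key_cache.shape[2])
    keys = self.rope(keys, offset=key_cache.shape[2])
    keys = mx.concatenate([key_cache, keys], axis=2)
    values = mx.concatenate([value_cache, values], axis=2)
else:
    queries = self.rope(queries)
    keys = self.rope(keys)
```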
@@ -315,7 +337,9 @@ def load_model(model_path):
         config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
         if config.get("vocab_size", -1) < 0:
             config["vocab_size"] = weights["output.weight"].shape[-1]
-        unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
+        if "rope_theta" not in config:
+            config["rope_theta"] = 10000
+        unused = ["multiple_of", "ffn_dim_multiplier"]
         for k in unused:
             if k in config:
                 config.pop(k)
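This keeps old checkpoints working: Llama 1/2 `params.json` files predate `rope_theta`, so the classic 10000 is backfilled before unused keys are stripped, and `rope_theta` itself is no longer discarded. A small standalone sketch of the same pattern (the file path is an illustrative assumption):

```python
import json

# Illustrative sketch mirroring the load_model change above.
with open("params.json") as f:  # path is an assumption for the sketch
    config = json.load(f)

# Backfill the default base for configs that predate rope_theta:
config.setdefault("rope_theta", 10000)

# rope_theta is no longer dropped as unused; only these keys are removed:
for k in ["multiple_of", "ffn_dim_multiplier"]:
    config.pop(k, None)
```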
@@ -206,7 +206,10 @@ if __name__ == "__main__":
 
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
-            eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
+            eos_index = next(
+                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+                None,
+            )
 
             if eos_index is not None:
                 tokens = tokens[:eos_index]
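The `eos_index` change is purely a line-length reflow; behavior is identical. The idiom finds the first EOS position with `next` over a generator, falling back to `None` when no EOS is present, then truncates. Demonstrated on plain integers (the eos id of 2 is illustrative):

```python
# Behavior-preserving demo of the eos truncation idiom, using plain ints
# in place of mx token arrays.
tokens = [5, 17, 9, 2, 42, 7]
eos_token_id = 2

eos_index = next((i for i, t in enumerate(tokens) if t == eos_token_id), None)

if eos_index is not None:
    tokens = tokens[:eos_index]

print(tokens)  # [5, 17, 9]
```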