mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
Merge branch 'ml-explore:main' into adding-full-finetuning
@@ -38,7 +38,9 @@ To see a description of all the arguments you can do:
>>> help(generate)
```

-Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
+Check out the [generation
+example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
+to see how to use the API in more detail.

The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
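
As a rough sketch (the Hugging Face path here is only illustrative), quantizing a model
to 4 bits and uploading it to the Hub looks like:

```bash
mlx_lm.convert \
    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
    -q \
    --upload-repo mlx-community/my-4bit-mistral
```
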
@@ -122,10 +124,44 @@ mlx_lm.convert \
    --upload-repo mlx-community/my-4bit-mistral
```

### Long Prompts and Generations

MLX LM has some tools to scale efficiently to long prompts and generations:

- A rotating fixed-size key-value cache.
- Prompt caching.

To use the rotating key-value cache, pass the argument `--max-kv-size n` where
`n` can be any integer. Smaller values like `512` will use very little RAM but
result in worse quality. Larger values like `4096` or higher will use more RAM
but have better quality.
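
For example, a minimal sketch of generating with a bounded cache (the model and prompt
here are placeholders):

```bash
mlx_lm.generate \
    --model mistralai/Mistral-7B-Instruct-v0.3 \
    --max-kv-size 512 \
    --prompt "Summarize the plot of Hamlet."
```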

Caching prompts can substantially speed up reusing the same long context with
different queries. To cache a prompt, use `mlx_lm.cache_prompt`. For example:

```bash
cat prompt.txt | mlx_lm.cache_prompt \
    --model mistralai/Mistral-7B-Instruct-v0.3 \
    --prompt - \
    --kv-cache-file mistral_prompt.safetensors
```

Then use the cached prompt with `mlx_lm.generate`:

```
mlx_lm.generate \
    --kv-cache-file mistral_prompt.safetensors \
    --prompt "\nSummarize the above text."
```

The cached prompt is treated as a prefix to the supplied prompt. Note also that
when using a cached prompt, the model is read from the cache and need not be
supplied explicitly (which is why the `mlx_lm.generate` command above omits
`--model`).

### Supported Models

-The example supports Hugging Face format Mistral, Llama, and Phi-2 style
-models. If the model you want to run is not supported, file an
+MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
@@ -56,7 +56,7 @@ def setup_arg_parser():
    parser.add_argument(
        "--max-kv-size",
        type=int,
-        default=1024,
+        default=None,
        help="Set the maximum key-value cache size",
    )
    parser.add_argument(
@@ -147,3 +147,7 @@ def main():
    metadata["tokenizer_config"] = json.dumps(tokenizer_config)
    metadata["max_kv_size"] = str(args.max_kv_size)
    mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)


if __name__ == "__main__":
    main()
@@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
-DEFAULT_MAX_KV_SIZE = 1024


def setup_arg_parser():
@@ -81,6 +80,7 @@ def setup_arg_parser():
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
+        default=None,
    )
    parser.add_argument(
        "--kv-cache-file",
@@ -171,7 +171,7 @@ def main():
    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template
-    elif tokenizer.chat_template is None:
+    elif cache_history is not None:
        tokenizer.chat_template = metadata["chat_template"]

    if not args.ignore_chat_template and (
@@ -199,12 +199,9 @@ def main():

    # Determine the max kv size from the kv cache or passed arguments
    max_kv_size = args.max_kv_size
-    if max_kv_size is None:
-        max_kv_size = (
-            int(metadata["max_kv_size"])
-            if cache_history is not None
-            else DEFAULT_MAX_KV_SIZE
-        )
+    if cache_history is not None:
+        max_kv_size = metadata["max_kv_size"]
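        # Note: cache_prompt stores str(args.max_kv_size), so a stored "None" (no limit)
        # fails isdigit() and maps back to None here.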
+        max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None

    generate(
        model,

llms/mlx_lm/models/nemotron.py (new file, 227 lines)

@@ -0,0 +1,227 @@
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, KVCache, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    hidden_act: str
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    partial_rotary_factor: float = 0.5
    rope_theta: float = 10000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        if self.rope_scaling:
            if not "factor" in self.rope_scaling:
                raise ValueError(f"rope_scaling must contain 'factor'")
            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
                "rope_type"
            )
            if rope_type is None:
                raise ValueError(
                    f"rope_scaling must contain either 'type' or 'rope_type'"
                )
            if rope_type not in ["linear"]:
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


@partial(mx.compile, shapeless=True)
def relu_squared(x):
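    # Squared ReLU: relu(x)**2, the activation Nemotron uses in its MLP.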
    return nn.relu(x).square()


class NemotronLayerNorm1P(nn.LayerNorm):
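    # LayerNorm variant where the stored weight is offset by 1: the effective scale is
    # (weight + 1), so a zero-initialized weight corresponds to an identity scale.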
    def __call__(self, x):
        weight = self.weight + 1 if "weight" in self else None
        bias = self.bias if "bias" in self else None
        return mx.fast.layer_norm(x, weight, bias, self.eps)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
        self.partial_rotary_factor = args.partial_rotary_factor

        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] == "linear":
            assert isinstance(args.rope_scaling["factor"], float)
            rope_scale = 1 / args.rope_scaling["factor"]
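        # Partial RoPE: rotary embeddings are applied to only the first
        # partial_rotary_factor * head_dim dimensions of each head.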
        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
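    # Nemotron's MLP has no gating branch: up-project, apply squared ReLU, down-project.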
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        mlp_bias = args.mlp_bias

        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(relu_squared(self.up_proj(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
        self.post_attention_layernorm = NemotronLayerNorm1P(
            args.hidden_size, eps=args.norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class NemotronModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = NemotronModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return (
            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
        )

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads

@@ -96,6 +96,7 @@ def linear_to_lora_layers(
        "llama",
        "phi",
        "mixtral",
+        "nemotron",
        "stablelm",
        "qwen2",
        "qwen2_moe",

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

-__version__ = "0.17.1"
+__version__ = "0.18.1"