From b1186e2a81c678087bcaab43202d040b126523c3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 29 Aug 2024 15:05:17 -0700 Subject: [PATCH 01/23] Docs on prompt scaling (#963) * docs on prompt scaling * remove unused var * nits --- llms/README.md | 42 ++++++++++++++++++++++++++++++++++--- llms/mlx_lm/cache_prompt.py | 6 +++++- llms/mlx_lm/generate.py | 11 ++++------ llms/mlx_lm/version.py | 2 +- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/llms/README.md b/llms/README.md index 497c0277..79f26d41 100644 --- a/llms/README.md +++ b/llms/README.md @@ -38,7 +38,9 @@ To see a description of all the arguments you can do: >>> help(generate) ``` -Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail. +Check out the [generation +example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) +to see how to use the API in more detail. The `mlx-lm` package also comes with functionality to quantize and optionally upload models to the Hugging Face Hub. @@ -122,10 +124,44 @@ mlx_lm.convert \ --upload-repo mlx-community/my-4bit-mistral ``` +### Long Prompts and Generations + +MLX LM has some tools to scale efficiently to long prompts and generations: + +- A rotating fixed-size key-value cache. +- Prompt caching + +To use the rotating key-value cache pass the argument `--max-kv-size n` where +`n` can be any integer. Smaller values like `512` will use very little RAM but +result in worse quality. Larger values like `4096` or higher will use more RAM +but have better quality. + +Caching prompts can substantially speedup reusing the same long context with +different queries. To cache a prompt use `mlx_lm.cache_prompt`. 
For example: + +```bash +cat prompt.txt | mlx_lm.cache_prompt \ + --model mistralai/Mistral-7B-Instruct-v0.3 \ + --prompt - \ + --kv-cache-file mistral_prompt.safetensors +``` + +Then use the cached prompt with `mlx_lm.generate`: + +``` +mlx_lm.generate \ + --kv-cache-file mistral_prompt.safetensors \ + --prompt "\nSummarize the above text." +``` + +The cached prompt is treated as a prefix to the supplied prompt. Also notice +when using a cached prompt, the model to use is read from the cache and need +not be supplied explicitly. + ### Supported Models -The example supports Hugging Face format Mistral, Llama, and Phi-2 style -models. If the model you want to run is not supported, file an +MLX LM supports thousands of Hugging Face format LLMs. If the model you want to +run is not supported, file an [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, submit a pull request. diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index ad045f1a..fe088118 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -56,7 +56,7 @@ def setup_arg_parser(): parser.add_argument( "--max-kv-size", type=int, - default=1024, + default=None, help="Set the maximum key-value cache size", ) parser.add_argument( @@ -147,3 +147,7 @@ def main(): metadata["tokenizer_config"] = json.dumps(tokenizer_config) metadata["max_kv_size"] = str(args.max_kv_size) mx.save_safetensors(args.kv_cache_file, cache_dict, metadata) + + +if __name__ == "__main__": + main() diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 4aa4001a..54f6f4d2 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.6 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 -DEFAULT_MAX_KV_SIZE = 1024 def setup_arg_parser(): @@ -81,6 +80,7 @@ def setup_arg_parser(): "--max-kv-size", type=int, help="Set the maximum key-value cache size", + default=None, ) parser.add_argument( "--kv-cache-file", @@ 
-199,12 +199,9 @@ def main(): # Determine the max kv size from the kv cache or passed arguments max_kv_size = args.max_kv_size - if max_kv_size is None: - max_kv_size = ( - int(metadata["max_kv_size"]) - if cache_history is not None - else DEFAULT_MAX_KV_SIZE - ) + if cache_history is not None: + max_kv_size = metadata["max_kv_size"] + max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None generate( model, diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 41237905..87e86846 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.17.1" +__version__ = "0.18.0" From fc93c557238e9441835afe2748fd170016cb068b Mon Sep 17 00:00:00 2001 From: L Date: Thu, 29 Aug 2024 21:08:57 -0700 Subject: [PATCH 02/23] feat(mlx_lm): Nemotron (#949) * feat: Nemotron https://huggingface.co/nvidia/Minitron-4B-Base This is basically Llama with partial RoPE and LayerNorm instead of BatchNorm. Also they add 1 to the LayerNorm weight for some reason. * fixup! feat: Nemotron * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/models/nemotron.py | 227 +++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + 2 files changed, 228 insertions(+) create mode 100644 llms/mlx_lm/models/nemotron.py diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py new file mode 100644 index 00000000..ef55d1d7 --- /dev/null +++ b/llms/mlx_lm/models/nemotron.py @@ -0,0 +1,227 @@ +# Copyright © 2024 Apple Inc. 
+ +from dataclasses import dataclass +from functools import partial +from typing import Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, KVCache, create_attention_mask + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + hidden_act: str + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + norm_eps: float + vocab_size: int + num_key_value_heads: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + partial_rotary_factor: float = 0.5 + rope_theta: float = 10000.0 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + if self.rope_scaling: + if not "factor" in self.rope_scaling: + raise ValueError(f"rope_scaling must contain 'factor'") + rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( + "rope_type" + ) + if rope_type is None: + raise ValueError( + f"rope_scaling must contain either 'type' or 'rope_type'" + ) + if rope_type not in ["linear"]: + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +@partial(mx.compile, shapeless=True) +def relu_squared(x): + return nn.relu(x).square() + + +class NemotronLayerNorm1P(nn.LayerNorm): + def __call__(self, x): + weight = self.weight + 1 if "weight" in self else None + bias = self.bias if "bias" in self else None + return mx.fast.layer_norm(x, weight, bias, self.eps) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + self.partial_rotary_factor = args.partial_rotary_factor + + self.scale = head_dim**-0.5 + if hasattr(args, 
"attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + rope_scale = 1.0 + if args.rope_scaling and args.rope_scaling["type"] == "linear": + assert isinstance(args.rope_scaling["factor"], float) + rope_scale = 1 / args.rope_scaling["factor"] + self.rope = nn.RoPE( + int(self.partial_rotary_factor * self.head_dim), + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + B, L, _ = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + mlp_bias = args.mlp_bias + + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + 
return self.down_proj(relu_squared(self.up_proj(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps) + self.post_attention_layernorm = NemotronLayerNorm1P( + args.hidden_size, eps=args.norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class NemotronModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = NemotronModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + 
out = self.lm_head(out) + return out + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return ( + self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads + ) + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 1a54a925..71fbfaab 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -93,6 +93,7 @@ def linear_to_lora_layers( "llama", "phi", "mixtral", + "nemotron", "stablelm", "qwen2", "qwen2_moe", From 3c6e8b11af9dda55ee8d38ee9f612a65118f7793 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 30 Aug 2024 05:56:27 -0700 Subject: [PATCH 03/23] fix (#965) --- llms/mlx_lm/generate.py | 2 +- llms/mlx_lm/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 54f6f4d2..f37037b6 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -171,7 +171,7 @@ def main(): if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - elif tokenizer.chat_template is None: + elif cache_history is not None: tokenizer.chat_template = metadata["chat_template"] if not args.ignore_chat_template and ( diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 87e86846..a2eb9a25 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.18.0" +__version__ = "0.18.1" From bf921afcbef89463c2b8a78c4008d993500fadb4 Mon Sep 17 00:00:00 2001 From: James Zhao Date: Tue, 3 Sep 2024 23:16:21 +0300 Subject: [PATCH 04/23] Make sure to import the correct "version" module when installing mlx_whisper and mlx_lm from local source code. (#969) * Make sure to import the correct "version" module when installing the mlx_whisper package from local source code. 
* Make sure to import the correct "version" module when installing the mlx_lm package from local source code * fix --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/__init__.py | 2 +- llms/mlx_lm/{version.py => _version.py} | 0 llms/setup.py | 2 +- whisper/mlx_whisper/__init__.py | 2 +- whisper/mlx_whisper/{version.py => _version.py} | 0 whisper/setup.py | 2 +- 6 files changed, 4 insertions(+), 4 deletions(-) rename llms/mlx_lm/{version.py => _version.py} (100%) rename whisper/mlx_whisper/{version.py => _version.py} (100%) diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py index e971c467..502c78e5 100644 --- a/llms/mlx_lm/__init__.py +++ b/llms/mlx_lm/__init__.py @@ -1,4 +1,4 @@ # Copyright © 2023-2024 Apple Inc. +from ._version import __version__ from .utils import convert, generate, load, stream_generate -from .version import __version__ diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/_version.py similarity index 100% rename from llms/mlx_lm/version.py rename to llms/mlx_lm/_version.py diff --git a/llms/setup.py b/llms/setup.py index ac294ae1..e2cfe0cd 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -10,7 +10,7 @@ with open(package_dir / "requirements.txt") as fid: requirements = [l.strip() for l in fid.readlines()] sys.path.append(str(package_dir)) -from version import __version__ +from _version import __version__ setup( name="mlx-lm", diff --git a/whisper/mlx_whisper/__init__.py b/whisper/mlx_whisper/__init__.py index e6de0858..14c5197f 100644 --- a/whisper/mlx_whisper/__init__.py +++ b/whisper/mlx_whisper/__init__.py @@ -1,5 +1,5 @@ # Copyright © 2023-2024 Apple Inc. from . 
import audio, decoding, load_models +from ._version import __version__ from .transcribe import transcribe -from .version import __version__ diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/_version.py similarity index 100% rename from whisper/mlx_whisper/version.py rename to whisper/mlx_whisper/_version.py diff --git a/whisper/setup.py b/whisper/setup.py index 086f6471..0cabd64b 100644 --- a/whisper/setup.py +++ b/whisper/setup.py @@ -12,7 +12,7 @@ with open(package_dir / "requirements.txt") as fid: sys.path.append(str(package_dir)) -from version import __version__ +from _version import __version__ setup( name="mlx-whisper", From 83a209e200eef505a688997f64afdff2c6e0d363 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Tue, 3 Sep 2024 16:29:10 -0400 Subject: [PATCH 05/23] Add prompt piping (#962) * Initial commit of --prompt-only and prompt from STDIN feature * Switch to using --verbose instead of --prompt-only * Fix capitalization typo * Fix reference to changed option name * Update exception text --- llms/mlx_lm/generate.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index f37037b6..537bd853 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -2,6 +2,7 @@ import argparse import json +import sys import mlx.core as mx @@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +def str2bool(string): + return string.lower() not in ["false", "f"] + + def setup_arg_parser(): """Set up and return the argument parser.""" parser = argparse.ArgumentParser(description="LLM inference script") @@ -39,7 +44,9 @@ def setup_arg_parser(): help="End of sequence token for tokenizer", ) parser.add_argument( - "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" + "--prompt", + default=DEFAULT_PROMPT, + help="Message to be processed by the model ('-' reads from stdin)", ) parser.add_argument( "--max-tokens", @@ -65,6 +72,12 
@@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) + parser.add_argument( + "--verbose", + type=str2bool, + default=True, + help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", + ) parser.add_argument( "--colorize", action="store_true", @@ -178,7 +191,12 @@ def main(): hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None ): - messages = [{"role": "user", "content": args.prompt}] + messages = [ + { + "role": "user", + "content": sys.stdin.read() if args.prompt == "-" else args.prompt, + } + ] prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @@ -195,6 +213,8 @@ def main(): else: prompt = args.prompt + if args.colorize and not args.verbose: + raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None # Determine the max kv size from the kv cache or passed arguments @@ -203,18 +223,20 @@ def main(): max_kv_size = metadata["max_kv_size"] max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None - generate( + response = generate( model, tokenizer, prompt, args.max_tokens, - verbose=True, + verbose=args.verbose, formatter=formatter, temp=args.temp, top_p=args.top_p, max_kv_size=max_kv_size, cache_history=cache_history, ) + if not args.verbose: + print(response) if __name__ == "__main__": From bd29aec299c8fa59c161a9c1207bfc59db31d845 Mon Sep 17 00:00:00 2001 From: madroid Date: Wed, 4 Sep 2024 21:19:32 +0800 Subject: [PATCH 06/23] Support HuggingFace model tree (#957) * Hub: Update quantization configuration fields * Hub: add base_model metadata * Hub: add quantization_config for model tree Quantized type * Hub: update quantization_config value * Hub: remove config print --- llms/mlx_lm/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 71476df3..eee28c9c 100644 --- a/llms/mlx_lm/utils.py +++ 
b/llms/mlx_lm/utils.py @@ -560,6 +560,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): card = ModelCard.load(hf_path) card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] + card.data.base_model = hf_path card.text = dedent( f""" # {upload_repo} @@ -666,6 +667,8 @@ def quantize_model( quantized_config = copy.deepcopy(config) nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + # support hf model tree #957 + quantized_config["quantization_config"] = quantized_config["quantization"] quantized_weights = dict(tree_flatten(model.parameters())) return quantized_weights, quantized_config From 324184d670ec11916a5e92314171d497b312eefe Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 6 Sep 2024 20:19:27 -0700 Subject: [PATCH 07/23] Fix the cache_prompt (#979) --- llms/mlx_lm/cache_prompt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index fe088118..9829efb4 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -139,8 +139,8 @@ def main(): print("Saving...") cache_dict = {} for i, c in enumerate(cache): - cache_dict[f"{i}_keys"] = c.state[0] - cache_dict[f"{i}_values"] = c.state[1] + cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :] + cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :] metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template From c3e3411756e098a3f5f29d988e9221034f1af47c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 7 Sep 2024 06:06:15 -0700 Subject: [PATCH 08/23] Update LLM generation docs to use chat template (#973) * fix docs * add template to model cards as well * revert * version --- llms/README.md | 14 +++++++++++++- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/utils.py | 11 ++++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git 
a/llms/README.md b/llms/README.md index 79f26d41..b8e1914d 100644 --- a/llms/README.md +++ b/llms/README.md @@ -29,7 +29,14 @@ from mlx_lm import load, generate model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") -response = generate(model, tokenizer, prompt="hello", verbose=True) +prompt = "Write a story about Einstein" + +messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +response = generate(model, tokenizer, prompt=prompt, verbose=True) ``` To see a description of all the arguments you can do: @@ -79,6 +86,11 @@ model, tokenizer = load(repo) prompt = "Write a story about Einstein" +messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + for t in stream_generate(model, tokenizer, prompt, max_tokens=512): print(t, end="", flush=True) print() diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index a2eb9a25..8110c823 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.18.1" +__version__ = "0.18.2" diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index eee28c9c..ad9b3221 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -577,7 +577,16 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): from mlx_lm import load, generate model, tokenizer = load("{upload_repo}") - response = generate(model, tokenizer, prompt="hello", verbose=True) + + prompt="hello" + + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: + messages = [{"role": "user", "content": prompt}] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + response = generate(model, tokenizer, prompt=prompt, verbose=True) ``` """ ) From 6c2369e4b97f49fb5906ec46033497b39931b25d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 7 Sep 2024 14:46:57 -0700 Subject: [PATCH 09/23] Fix bug in upload + docs nit (#981) * fix bug in upload + docs nit * nit --- llms/mlx_lm/LORA.md | 30 +++++++----------------------- llms/mlx_lm/utils.py | 2 +- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2e739d0f..2d9a2553 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -166,44 +166,28 @@ Currently, `*.jsonl` files support three data formats: `chat`, `chat`: ```jsonl -{ - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello." - }, - { - "role": "assistant", - "content": "How can I assistant you today." - } - ] -} +{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]} ``` `completions`: ```jsonl -{ - "prompt": "What is the capital of France?", - "completion": "Paris." 
-} +{"prompt": "What is the capital of France?", "completion": "Paris."} ``` `text`: ```jsonl -{ - "text": "This is an example for the model." -} +{"text": "This is an example for the model."} ``` Note, the format is automatically determined by the dataset. Note also, keys in each line not expected by the loader will be ignored. +> [!NOTE] +> Each example in the datasets must be on a single line. Do not put more than +> one example per line and do not split an example across multiple lines. + ### Hugging Face Datasets To use Hugging Face datasets, first install the `datasets` package: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index ad9b3221..b4a2ea51 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -581,7 +581,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): prompt="hello" if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: - messages = [{"role": "user", "content": prompt}] + messages = [{{"role": "user", "content": prompt}}] prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) From f530f56df2738a54982c4541189a8c8d7cd94c44 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 17 Sep 2024 16:22:48 -0700 Subject: [PATCH 10/23] don't use internal exception (#990) --- llms/mlx_lm/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b4a2ea51..5621609d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -14,7 +14,6 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from huggingface_hub.utils._errors import RepositoryNotFoundError from mlx.utils import tree_flatten from transformers import PreTrainedTokenizer @@ -91,7 +90,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path ], ) ) - except RepositoryNotFoundError: 
+ except: raise ModelNotFoundError( f"Model not found for path or HF repo: {path_or_hf_repo}.\n" "Please make sure you specified the local path or Hugging Face" From 796d5e40e4cce0e0d49d3b3b3c00957b31702fe0 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 20 Sep 2024 13:33:45 -0700 Subject: [PATCH 11/23] Fix export to gguf (#993) --- llms/mlx_lm/gguf.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py index 5d524580..241ac35a 100644 --- a/llms/mlx_lm/gguf.py +++ b/llms/mlx_lm/gguf.py @@ -67,7 +67,7 @@ class HfVocab: def get_token_type( self, token_id: int, token_text: bytes, special_ids: Set[int] ) -> TokenType: - if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")): + if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text): return TokenType.BYTE return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL @@ -77,9 +77,7 @@ class HfVocab: def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]: for text in self.added_tokens_list: if text in self.specials: - toktype = self.get_token_type( - self.specials[text], b"", self.special_ids - ) + toktype = self.get_token_type(self.specials[text], "", self.special_ids) score = self.get_token_score(self.specials[text]) else: toktype = TokenType.USER_DEFINED @@ -243,15 +241,18 @@ def prepare_metadata(config, vocab): metadata["tokenizer.ggml.tokens"] = tokens metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32) metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32) - metadata["tokenizer.ggml.bos_token_id"] = mx.array( - vocab.tokenizer.bos_token_id, dtype=mx.uint32 - ) - metadata["tokenizer.ggml.eos_token_id"] = mx.array( - vocab.tokenizer.eos_token_id, dtype=mx.uint32 - ) - metadata["tokenizer.ggml.unknown_token_id"] = mx.array( - vocab.tokenizer.unk_token_id, dtype=mx.uint32 - ) + if vocab.tokenizer.bos_token_id is not None: + 
metadata["tokenizer.ggml.bos_token_id"] = mx.array( + vocab.tokenizer.bos_token_id, dtype=mx.uint32 + ) + if vocab.tokenizer.eos_token_id is not None: + metadata["tokenizer.ggml.eos_token_id"] = mx.array( + vocab.tokenizer.eos_token_id, dtype=mx.uint32 + ) + if vocab.tokenizer.unk_token_id is not None: + metadata["tokenizer.ggml.unknown_token_id"] = mx.array( + vocab.tokenizer.unk_token_id, dtype=mx.uint32 + ) metadata = {k: v for k, v in metadata.items() if v is not None} return metadata From 9bb2dd62f350d9f0f1a8003e354431b5172831e5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 23 Sep 2024 11:39:25 -0700 Subject: [PATCH 12/23] Encodec (#991) * initial encodec * works * nits * use fast group norm * fix for rnn layer * fix mlx version * use custom LSTM kernel * audio encodec * fix example, support batched inference * nits --- README.md | 1 + encodec/README.md | 83 ++++ encodec/benchmarks/bench_mx.py | 30 ++ encodec/benchmarks/bench_pt.py | 34 ++ encodec/convert.py | 213 +++++++++++ encodec/encodec.py | 671 +++++++++++++++++++++++++++++++++ encodec/example.py | 37 ++ encodec/requirements.txt | 3 + encodec/test.py | 66 ++++ encodec/utils.py | 129 +++++++ 10 files changed, 1267 insertions(+) create mode 100644 encodec/README.md create mode 100644 encodec/benchmarks/bench_mx.py create mode 100644 encodec/benchmarks/bench_pt.py create mode 100644 encodec/convert.py create mode 100644 encodec/encodec.py create mode 100644 encodec/example.py create mode 100644 encodec/requirements.txt create mode 100644 encodec/test.py create mode 100644 encodec/utils.py diff --git a/README.md b/README.md index 00e57803..bd180975 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Some more useful examples are listed below. ### Audio Models - Speech recognition with [OpenAI's Whisper](whisper). +- Audio compression and generation with [Meta's EnCodec](encodec). 
### Multimodal models diff --git a/encodec/README.md b/encodec/README.md new file mode 100644 index 00000000..3ab2793c --- /dev/null +++ b/encodec/README.md @@ -0,0 +1,83 @@ +# EnCodec + +An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and +generate audio. + +### Setup + +Install the requirements: + +``` +pip install -r requirements.txt +``` + +Optionally install FFmpeg and SciPy for loading and saving audio files, +respectively. + +Install [FFmpeg](https://ffmpeg.org/): + +``` +# on macOS using Homebrew (https://brew.sh/) +brew install ffmpeg +``` + +Install SciPy: + +``` +pip install scipy +``` + +### Example + +An example using the model: + +```python +import mlx.core as mx +from utils import load, load_audio, save_audio + +# Load the 48 KHz model and preprocessor. +model, processor = load("mlx-community/encodec-48khz-float32") + +# Load an audio file +audio = load_audio("path/to/audio", model.sampling_rate, model.channels) + +# Preprocess the audio (this can also be a list of arrays for batched +# processing). +feats, mask = processor(audio) + +# Encode at the given bandwidth. A lower bandwidth results in more +# compression but lower reconstruction quality. +@mx.compile +def encode(feats, mask): + return model.encode(feats, mask, bandwidth=3) + +# Decode to reconstruct the audio +@mx.compile +def decode(codes, scales, mask): + return model.decode(codes, scales, mask) + + +codes, scales = encode(feats, mask) +reconstructed = decode(codes, scales, mask) + +# Trim any padding: +reconstructed = reconstructed[0, : len(audio)] + +# Save the audio as a wave file +save_audio("reconstructed.wav", reconstructed, model.sampling_rate) +``` + +The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the +[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164) +in several data types. + +### Optional + +To convert models, use the `convert.py` script. 
To see the options, run: + +```bash +python convert.py -h +``` + +[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and + [code](https://github.com/facebookresearch/encodec) for more details. diff --git a/encodec/benchmarks/bench_mx.py b/encodec/benchmarks/bench_mx.py new file mode 100644 index 00000000..2acd4b75 --- /dev/null +++ b/encodec/benchmarks/bench_mx.py @@ -0,0 +1,30 @@ +# Copyright © 2024 Apple Inc. + +import time + +import mlx.core as mx +from utils import load + +model, processor = load("mlx-community/encodec-48khz-float32") + +audio = mx.random.uniform(shape=(288000, 2)) +feats, mask = processor(audio) +mx.eval(model, feats, mask) + + +@mx.compile +def fun(): + codes, scales = model.encode(feats, mask, bandwidth=3) + reconstructed = model.decode(codes, scales, mask) + return reconstructed + + +for _ in range(5): + mx.eval(fun()) + +tic = time.time() +for _ in range(10): + mx.eval(fun()) +toc = time.time() +ms = 1000 * (toc - tic) / 10 +print(f"Time per it: {ms:.3f}") diff --git a/encodec/benchmarks/bench_pt.py b/encodec/benchmarks/bench_pt.py new file mode 100644 index 00000000..5d158a32 --- /dev/null +++ b/encodec/benchmarks/bench_pt.py @@ -0,0 +1,34 @@ +# Copyright © 2024 Apple Inc. 
+ +import time + +import numpy as np +import torch +from transformers import AutoProcessor, EncodecModel + +processor = AutoProcessor.from_pretrained("facebook/encodec_48khz") +audio = np.random.uniform(size=(2, 288000)).astype(np.float32) + +pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps") +pt_inputs = processor( + raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt" +).to("mps") + + +def fun(): + pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"]) + pt_audio = pt_model.decode( + pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"] + ) + torch.mps.synchronize() + + +for _ in range(5): + fun() + +tic = time.time() +for _ in range(10): + fun() +toc = time.time() +ms = 1000 * (toc - tic) / 10 +print(f"Time per it: {ms:.3f}") diff --git a/encodec/convert.py b/encodec/convert.py new file mode 100644 index 00000000..13bd31a6 --- /dev/null +++ b/encodec/convert.py @@ -0,0 +1,213 @@ +# Copyright © 2024 Apple Inc. + +import argparse +import json +from pathlib import Path +from textwrap import dedent +from types import SimpleNamespace +from typing import Any, Dict, Union + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download +from mlx.utils import tree_flatten + +import encodec + + +def fetch_from_hub(hf_repo: str) -> Path: + model_path = Path( + snapshot_download( + repo_id=hf_repo, + allow_patterns=["*.json", "*.safetensors"], + ) + ) + return model_path + + +def upload_to_hub(path: str, upload_repo: str, hf_path: str): + """ + Uploads the model to Hugging Face hub. + + Args: + path (str): Local path to the model. + upload_repo (str): Name of the HF repo to upload to. + hf_path (str): Path to the original Hugging Face model. 
+ """ + import os + + from huggingface_hub import HfApi, ModelCard, logging + + content = dedent( + f""" + --- + language: en + license: other + library: mlx + tags: + - mlx + --- + + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was + converted to MLX format from + [{hf_path}](https://huggingface.co/{hf_path}). + + This model is intended to be used with the [EnCodec MLX + example](https://github.com/ml-explore/mlx-examples/tree/main/encodec). + """ + ) + + card = ModelCard(content) + card.save(os.path.join(path, "README.md")) + + logging.set_verbosity_info() + + api = HfApi() + api.create_repo(repo_id=upload_repo, exist_ok=True) + api.upload_folder( + folder_path=path, + repo_id=upload_repo, + repo_type="model", + multi_commits=True, + multi_commits_verbose=True, + ) + print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") + + +def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None: + if isinstance(save_path, str): + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + total_size = sum(v.nbytes for v in weights.values()) + index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} + mx.save_safetensors( + str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"} + ) + + for weight_name in weights.keys(): + index_data["weight_map"][weight_name] = "model.safetensors" + + index_data["weight_map"] = { + k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) + } + + with open(save_path / "model.safetensors.index.json", "w") as f: + json.dump(index_data, f, indent=4) + + +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. 
+ """ + # Clean unused keys + config.pop("_name_or_path", None) + + # sort the config for better readability + config = dict(sorted(config.items())) + + # write the updated config to the config_path (if provided) + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) + + +def convert( + upload: bool, + model: str, + dtype: str = None, +): + hf_repo = f"facebook/encodec_{model}" + mlx_repo = f"mlx-community/encodec-{model}-{dtype}" + path = fetch_from_hub(hf_repo) + save_path = Path("mlx_models") + + weights = mx.load(str(Path(path) / "model.safetensors")) + + with open(path / "config.json", "r") as fid: + config = SimpleNamespace(**json.load(fid)) + + model = encodec.EncodecModel(config) + + new_weights = {} + for k, v in weights.items(): + basename, pname = k.rsplit(".", 1) + if pname == "weight_v": + g = weights[basename + ".weight_g"] + v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True)) + k = basename + ".weight" + elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]: + continue + elif "lstm" in basename: + w_or_b, ih_or_hh, ln = pname.split("_") + if w_or_b == "weight": + new_pname = "Wx" if ih_or_hh == "ih" else "Wh" + elif w_or_b == "bias" and ih_or_hh == "ih": + continue + else: + v = v + weights[k.replace("_hh_", "_ih_")] + new_pname = "bias" + k = basename + "." + ln[1:] + "." 
+ new_pname + if "conv.weight" in k: + # Possibly a transposed conv which has a different order + if "decoder" in k: + ln = int(k.split(".")[2]) + if "conv" in model.decoder.layers[ln] and isinstance( + model.decoder.layers[ln].conv, nn.ConvTranspose1d + ): + v = mx.moveaxis(v, 0, 2) + else: + v = mx.moveaxis(v, 1, 2) + else: + v = mx.moveaxis(v, 1, 2) + + new_weights[k] = v + weights = new_weights + + model.load_weights(list(weights.items())) + + if dtype is not None: + t = getattr(mx, dtype) + weights = {k: v.astype(t) for k, v in weights.items()} + + if isinstance(save_path, str): + save_path = Path(save_path) + + save_weights(save_path, weights) + + save_config(vars(config), config_path=save_path / "config.json") + + if upload: + upload_to_hub(save_path, mlx_repo, hf_repo) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.") + parser.add_argument( + "--model", + type=str, + default="48khz", + help="", + choices=["24khz", "32khz", "48khz"], + ) + parser.add_argument( + "--upload", + action="store_true", + help="Upload the weights to Hugging Face.", + ) + parser.add_argument( + "--dtype", + type=str, + help="Data type to convert the model to.", + default="float32", + choices=["float32", "bfloat16", "float16"], + ) + args = parser.parse_args() + convert(upload=args.upload, model=args.model, dtype=args.dtype) diff --git a/encodec/encodec.py b/encodec/encodec.py new file mode 100644 index 00000000..3ef47369 --- /dev/null +++ b/encodec/encodec.py @@ -0,0 +1,671 @@ +# Copyright © 2024 Apple Inc. 
+ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +_lstm_kernel = mx.fast.metal_kernel( + name="lstm", + input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"], + output_names=["hidden_state", "cell_state"], + header=""" + template + T sigmoid(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 1 - y : y; + } + """, + source=""" + uint b = thread_position_in_grid.x; + uint d = hidden_size * 4; + + uint elem = b * d + thread_position_in_grid.y; + uint index = elem; + uint x_index = b * num_time_steps * d + time_step * d + index; + + auto i = sigmoid(h_in[index] + x[x_index]); + index += hidden_size; + x_index += hidden_size; + auto f = sigmoid(h_in[index] + x[x_index]); + index += hidden_size; + x_index += hidden_size; + auto g = metal::precise::tanh(h_in[index] + x[x_index]); + index += hidden_size; + x_index += hidden_size; + auto o = sigmoid(h_in[index] + x[x_index]); + + cell_state[elem] = f * cell[elem] + i * g; + hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]); + """, +) + + +def lstm_custom(x, h_in, cell, time_step): + assert x.ndim == 3, "Input to LSTM must have 3 dimensions." 
+ out_shape = cell.shape + return _lstm_kernel( + inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]], + output_shapes=[out_shape, out_shape], + output_dtypes=[h_in.dtype, h_in.dtype], + grid=(x.shape[0], h_in.size // 4, 1), + threadgroup=(256, 1, 1), + ) + + +class LSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ): + super().__init__() + + self.hidden_size = hidden_size + self.Wx = mx.zeros((4 * hidden_size, input_size)) + self.Wh = mx.zeros((4 * hidden_size, hidden_size)) + self.bias = mx.zeros((4 * hidden_size,)) if bias else None + + def __call__(self, x, hidden=None, cell=None): + if self.bias is not None: + x = mx.addmm(self.bias, x, self.Wx.T) + else: + x = x @ self.Wx.T + + all_hidden = [] + + B = x.shape[0] + cell = cell or mx.zeros((B, self.hidden_size), x.dtype) + for t in range(x.shape[-2]): + if hidden is None: + hidden = mx.zeros((B, self.hidden_size * 4), x.dtype) + else: + hidden = hidden @ self.Wh.T + hidden, cell = lstm_custom(x, hidden, cell, t) + all_hidden.append(hidden) + + return mx.stack(all_hidden, axis=-2) + + +class EncodecConv1d(nn.Module): + """Conv1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + ): + super().__init__() + self.causal = config.use_causal_conv + self.pad_mode = config.pad_mode + self.norm_type = config.norm_type + + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, dilation=dilation + ) + if self.norm_type == "time_group_norm": + self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True) + + self.stride = stride + + # Effective kernel size with dilations. 
+ self.kernel_size = (kernel_size - 1) * dilation + 1 + + self.padding_total = kernel_size - stride + + def _get_extra_padding_for_conv1d( + self, + hidden_states: mx.array, + ) -> mx.array: + length = hidden_states.shape[1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = int(math.ceil(n_frames)) - 1 + ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total + return ideal_length - length + + def _pad1d( + self, + hidden_states: mx.array, + paddings: Tuple[int, int], + mode: str = "zero", + value: float = 0.0, + ): + if mode != "reflect": + return mx.pad( + hidden_states, paddings, mode="constant", constant_values=value + ) + + length = hidden_states.shape[1] + prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1] + suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1] + return mx.concatenate([prefix, hidden_states, suffix], axis=1) + + def __call__(self, hidden_states): + extra_padding = self._get_extra_padding_for_conv1d(hidden_states) + + if self.causal: + # Left padding for causal + hidden_states = self._pad1d( + hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode + ) + else: + # Asymmetric padding required for odd strides + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + hidden_states = self._pad1d( + hidden_states, + (padding_left, padding_right + extra_padding), + mode=self.pad_mode, + ) + + hidden_states = self.conv(hidden_states) + + if self.norm_type == "time_group_norm": + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class EncodecConvTranspose1d(nn.Module): + """ConvTranspose1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + ): + super().__init__() + self.causal = config.use_causal_conv + self.trim_right_ratio = config.trim_right_ratio + self.norm_type 
= config.norm_type + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride) + if config.norm_type == "time_group_norm": + self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True) + self.padding_total = kernel_size - stride + + def __call__(self, hidden_states): + hidden_states = self.conv(hidden_states) + + if self.norm_type == "time_group_norm": + hidden_states = self.norm(hidden_states) + + if self.causal: + padding_right = math.ceil(self.padding_total * self.trim_right_ratio) + else: + padding_right = self.padding_total // 2 + + padding_left = self.padding_total - padding_right + + end = hidden_states.shape[1] - padding_right + hidden_states = hidden_states[:, padding_left:end, :] + return hidden_states + + +class EncodecLSTM(nn.Module): + def __init__(self, config, dimension): + super().__init__() + self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)] + + def __call__(self, hidden_states): + h = hidden_states + for lstm in self.lstm: + h = lstm(h) + return h + hidden_states + + +class EncodecResnetBlock(nn.Module): + """ + Residual block from SEANet model as used by EnCodec. 
+ """ + + def __init__(self, config, dim: int, dilations: List[int]): + super().__init__() + kernel_sizes = (config.residual_kernel_size, 1) + if len(kernel_sizes) != len(dilations): + raise ValueError("Number of kernel sizes should match number of dilations") + + hidden = dim // config.compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [nn.ELU()] + block += [ + EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation) + ] + self.block = block + + if getattr(config, "use_conv_shortcut", True): + self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) + else: + self.shortcut = nn.Identity() + + def __call__(self, hidden_states): + residual = hidden_states + for layer in self.block: + hidden_states = layer(hidden_states) + + return self.shortcut(residual) + hidden_states + + +class EncodecEncoder(nn.Module): + """SEANet encoder as used by EnCodec.""" + + def __init__(self, config): + super().__init__() + model = [ + EncodecConv1d( + config, config.audio_channels, config.num_filters, config.kernel_size + ) + ] + scaling = 1 + + for ratio in reversed(config.upsampling_ratios): + current_scale = scaling * config.num_filters + for j in range(config.num_residual_layers): + model += [ + EncodecResnetBlock( + config, current_scale, [config.dilation_growth_rate**j, 1] + ) + ] + model += [nn.ELU()] + model += [ + EncodecConv1d( + config, + current_scale, + current_scale * 2, + kernel_size=ratio * 2, + stride=ratio, + ) + ] + scaling *= 2 + + model += [EncodecLSTM(config, scaling * config.num_filters)] + model += [nn.ELU()] + model += [ + EncodecConv1d( + config, + scaling * config.num_filters, + config.hidden_size, + config.last_kernel_size, + ) + ] + + self.layers = model + + def __call__(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + 
+class EncodecDecoder(nn.Module): + """SEANet decoder as used by EnCodec.""" + + def __init__(self, config): + super().__init__() + scaling = int(2 ** len(config.upsampling_ratios)) + model = [ + EncodecConv1d( + config, + config.hidden_size, + scaling * config.num_filters, + config.kernel_size, + ) + ] + + model += [EncodecLSTM(config, scaling * config.num_filters)] + + for ratio in config.upsampling_ratios: + current_scale = scaling * config.num_filters + model += [nn.ELU()] + model += [ + EncodecConvTranspose1d( + config, + current_scale, + current_scale // 2, + kernel_size=ratio * 2, + stride=ratio, + ) + ] + for j in range(config.num_residual_layers): + model += [ + EncodecResnetBlock( + config, current_scale // 2, (config.dilation_growth_rate**j, 1) + ) + ] + scaling //= 2 + + model += [nn.ELU()] + model += [ + EncodecConv1d( + config, + config.num_filters, + config.audio_channels, + config.last_kernel_size, + ) + ] + self.layers = model + + def __call__(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class EncodecEuclideanCodebook(nn.Module): + """Codebook with Euclidean distance.""" + + def __init__(self, config): + super().__init__() + self.embed = mx.zeros((config.codebook_size, config.codebook_dim)) + + def quantize(self, hidden_states): + embed = self.embed.T + scaled_states = hidden_states.square().sum(axis=1, keepdims=True) + dist = -( + scaled_states + - 2 * hidden_states @ embed + + embed.square().sum(axis=0, keepdims=True) + ) + embed_ind = dist.argmax(axis=-1) + return embed_ind + + def encode(self, hidden_states): + shape = hidden_states.shape + hidden_states = hidden_states.reshape((-1, shape[-1])) + embed_ind = self.quantize(hidden_states) + embed_ind = embed_ind.reshape(*shape[:-1]) + return embed_ind + + def decode(self, embed_ind): + return self.embed[embed_ind] + + +class EncodecVectorQuantization(nn.Module): + """ + Vector quantization implementation. 
Currently supports only euclidean distance. + """ + + def __init__(self, config): + super().__init__() + self.codebook = EncodecEuclideanCodebook(config) + + def encode(self, hidden_states): + return self.codebook.encode(hidden_states) + + def decode(self, embed_ind): + return self.codebook.decode(embed_ind) + + +class EncodecResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer.""" + + def __init__(self, config): + super().__init__() + self.codebook_size = config.codebook_size + + hop_length = np.prod(config.upsampling_ratios) + self.frame_rate = math.ceil(config.sampling_rate / hop_length) + self.num_quantizers = int( + 1000 * config.target_bandwidths[-1] // (self.frame_rate * 10) + ) + self.layers = [ + EncodecVectorQuantization(config) for _ in range(self.num_quantizers) + ] + + def get_num_quantizers_for_bandwidth( + self, bandwidth: Optional[float] = None + ) -> int: + """Return num_quantizers based on specified target bandwidth.""" + bw_per_q = math.log2(self.codebook_size) * self.frame_rate + num_quantizers = self.num_quantizers + if bandwidth is not None and bandwidth > 0.0: + num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q))) + return num_quantizers + + def encode( + self, embeddings: mx.array, bandwidth: Optional[float] = None + ) -> mx.array: + """ + Encode a given input array with the specified frame rate at the given + bandwidth. The RVQ encode method sets the appropriate number of + quantizers to use and returns indices for each quantizer. 
+ """ + num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth) + residual = embeddings + all_indices = [] + for layer in self.layers[:num_quantizers]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = mx.stack(all_indices, axis=1) + return out_indices + + def decode(self, codes: mx.array) -> mx.array: + """Decode the given codes to the quantized representation.""" + quantized_out = None + for i, indices in enumerate(codes.split(codes.shape[1], axis=1)): + layer = self.layers[i] + quantized = layer.decode(indices.squeeze(1)) + if quantized_out is None: + quantized_out = quantized + else: + quantized_out = quantized + quantized_out + return quantized_out + + +class EncodecModel(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.encoder = EncodecEncoder(config) + self.decoder = EncodecDecoder(config) + self.quantizer = EncodecResidualVectorQuantizer(config) + + def _encode_frame( + self, input_values: mx.array, bandwidth: float, padding_mask: mx.array + ) -> Tuple[mx.array, Optional[mx.array]]: + """ + Encodes the given input using the underlying VQVAE. 
+ """ + length = input_values.shape[1] + duration = length / self.config.sampling_rate + + if ( + self.config.chunk_length_s is not None + and duration > 1e-5 + self.config.chunk_length_s + ): + raise RuntimeError( + f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}" + ) + + scale = None + if self.config.normalize: + # if the padding is non zero + input_values = input_values * padding_mask[..., None] + mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2] + scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8 + input_values = input_values / scale + + embeddings = self.encoder(input_values) + codes = self.quantizer.encode(embeddings, bandwidth) + return codes, scale + + def encode( + self, + input_values: mx.array, + padding_mask: mx.array = None, + bandwidth: Optional[float] = None, + ) -> Tuple[mx.array, Optional[mx.array]]: + """ + Encodes the input audio waveform into discrete codes. + + Args: + input_values (mx.array): The input audio waveform with shape + ``(batch_size, channels, sequence_length)``. + padding_mask (mx.array): Padding mask used to pad the ``input_values``. + bandwidth (float, optional): The target bandwidth. Must be one of + ``config.target_bandwidths``. If ``None``, uses the smallest + possible bandwidth. bandwidth is represented as a thousandth of + what it is, e.g. 6kbps bandwidth is represented as bandwidth == 6.0 + + Returns: + A list of frames containing the discrete encoded codes for the + input audio waveform, along with rescaling factors for each chunk + when ``config.normalize==True``. Each frame is a tuple ``(codebook, + scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks, + frames)``. + """ + + if bandwidth is None: + bandwidth = self.config.target_bandwidths[0] + if bandwidth not in self.config.target_bandwidths: + raise ValueError( + f"This model doesn't support the bandwidth {bandwidth}. " + f"Select one of {self.config.target_bandwidths}." 
+ ) + + _, input_length, channels = input_values.shape + + if channels < 1 or channels > 2: + raise ValueError( + f"Number of audio channels must be 1 or 2, but got {channels}" + ) + + chunk_length = self.chunk_length + if chunk_length is None: + chunk_length = input_length + stride = input_length + else: + stride = self.chunk_stride + + if padding_mask is None: + padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_) + encoded_frames = [] + scales = [] + + step = chunk_length - stride + if (input_length % stride) != step: + raise ValueError( + "The input length is not properly padded for batched chunked " + "encoding. Make sure to pad the input correctly." + ) + + for offset in range(0, input_length - step, stride): + mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_) + frame = input_values[:, offset : offset + chunk_length] + encoded_frame, scale = self._encode_frame(frame, bandwidth, mask) + encoded_frames.append(encoded_frame) + scales.append(scale) + + encoded_frames = mx.stack(encoded_frames) + + return (encoded_frames, scales) + + @staticmethod + def _linear_overlap_add(frames: List[mx.array], stride: int): + if len(frames) == 0: + raise ValueError("`frames` cannot be an empty list.") + + dtype = frames[0].dtype + N, frame_length, C = frames[0].shape + total_size = stride * (len(frames) - 1) + frames[-1].shape[1] + + time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1] + weight = 0.5 - (time_vec - 0.5).abs() + + weight = weight[:, None] + sum_weight = mx.zeros((total_size, 1), dtype=dtype) + out = mx.zeros((N, total_size, C), dtype=dtype) + offset = 0 + + for frame in frames: + frame_length = frame.shape[1] + out[:, offset : offset + frame_length] += weight[:frame_length] * frame + sum_weight[offset : offset + frame_length] += weight[:frame_length] + offset += stride + + return out / sum_weight + + def _decode_frame( + self, codes: mx.array, scale: Optional[mx.array] = None + ) -> mx.array: + embeddings = 
self.quantizer.decode(codes) + outputs = self.decoder(embeddings) + if scale is not None: + outputs = outputs * scale + return outputs + + @property + def channels(self): + return self.config.audio_channels + + @property + def sampling_rate(self): + return self.config.sampling_rate + + @property + def chunk_length(self): + if self.config.chunk_length_s is None: + return None + else: + return int(self.config.chunk_length_s * self.config.sampling_rate) + + @property + def chunk_stride(self): + if self.config.chunk_length_s is None or self.config.overlap is None: + return None + else: + return max(1, int((1.0 - self.config.overlap) * self.chunk_length)) + + def decode( + self, + audio_codes: mx.array, + audio_scales: Union[mx.array, List[mx.array]], + padding_mask: Optional[mx.array] = None, + ) -> Tuple[mx.array, mx.array]: + """ + Decodes the given frames into an output audio waveform. + + Note that the output might be a bit bigger than the input. In that + case, any extra steps at the end should be trimmed. + + Args: + audio_codes (mx.array): Discrete code embeddings of shape + ``(batch_size, nb_chunks, chunk_length)``. + audio_scales (mx.array): Scaling factor for each input. + padding_mask (mx.array): Padding mask. 
+ """ + chunk_length = self.chunk_length + if chunk_length is None: + if audio_codes.shape[1] != 1: + raise ValueError(f"Expected one frame, got {len(audio_codes)}") + audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0]) + else: + decoded_frames = [] + + for frame, scale in zip(audio_codes, audio_scales): + frames = self._decode_frame(frame, scale) + decoded_frames.append(frames) + + audio_values = self._linear_overlap_add( + decoded_frames, self.chunk_stride or 1 + ) + + # truncate based on padding mask + if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]: + audio_values = audio_values[:, : padding_mask.shape[1]] + return audio_values diff --git a/encodec/example.py b/encodec/example.py new file mode 100644 index 00000000..97b311a1 --- /dev/null +++ b/encodec/example.py @@ -0,0 +1,37 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +from utils import load, load_audio, save_audio + +# Load the 48 KHz model and preprocessor. +model, processor = load("mlx-community/encodec-48khz-float32") + +# Load an audio file +audio = load_audio("/path/to/audio", model.sampling_rate, model.channels) + +# Preprocess the audio (this can also be a list of arrays for batched +# processing). +feats, mask = processor(audio) + + +# Encode at the given bandwidth. A lower bandwidth results in more +# compression but lower reconstruction quality. 
+@mx.compile +def encode(feats, mask): + return model.encode(feats, mask, bandwidth=3) + + +# Decode to reconstruct the audio +@mx.compile +def decode(codes, scales, mask): + return model.decode(codes, scales, mask) + + +codes, scales = encode(feats, mask) +reconstructed = decode(codes, scales, mask) + +# Trim any padding: +reconstructed = reconstructed[0, : len(audio)] + +# Save the audio as a wave file +save_audio("reconstructed.wav", reconstructed, model.sampling_rate) diff --git a/encodec/requirements.txt b/encodec/requirements.txt new file mode 100644 index 00000000..de5cc646 --- /dev/null +++ b/encodec/requirements.txt @@ -0,0 +1,3 @@ +mlx>=0.18 +numpy +huggingface_hub diff --git a/encodec/test.py b/encodec/test.py new file mode 100644 index 00000000..ffc23505 --- /dev/null +++ b/encodec/test.py @@ -0,0 +1,66 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +import numpy as np +import torch +from datasets import Audio, load_dataset +from transformers import AutoProcessor, EncodecModel +from utils import load, load_audio, preprocess_audio + + +def compare_processors(): + np.random.seed(0) + audio_length = 95500 + audio = np.random.uniform(size=(2, audio_length)).astype(np.float32) + + processor = AutoProcessor.from_pretrained("facebook/encodec_48khz") + + pt_inputs = processor( + raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt" + ) + mx_inputs = preprocess_audio( + mx.array(audio).T, + processor.sampling_rate, + processor.chunk_length, + processor.chunk_stride, + ) + + assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1)) + assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1]) + + +def compare_models(): + pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz") + mx_model, _ = load("mlx-community/encodec-48khz-float32") + + np.random.seed(0) + audio_length = 190560 + audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32) + mask = np.ones((1, audio_length), 
dtype=np.int32) + pt_encoded = pt_model.encode( + torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None] + ) + mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask)) + pt_codes = pt_encoded.audio_codes.numpy() + mx_codes = mx_encoded[0] + assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch" + + for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales): + if mx_scale is not None: + pt_scale = pt_scale.numpy() + assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4) + + pt_audio = pt_model.decode( + pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None] + ) + pt_audio = pt_audio[0].squeeze().T.detach().numpy() + mx_audio = mx_model.decode(*mx_encoded, mx.array(mask)) + mx_audio = mx_audio.squeeze() + assert np.allclose( + pt_audio, mx_audio, atol=1e-4, rtol=1e-4 + ), "Decoding audio mismatch" + + +if __name__ == "__main__": + compare_processors() + compare_models() diff --git a/encodec/utils.py b/encodec/utils.py new file mode 100644 index 00000000..18b3f063 --- /dev/null +++ b/encodec/utils.py @@ -0,0 +1,129 @@ +# Copyright © 2024 Apple Inc. + +import functools +import json +from pathlib import Path +from types import SimpleNamespace +from typing import List, Optional, Union + +import mlx.core as mx +import numpy as np +from huggingface_hub import snapshot_download + +import encodec + + +def save_audio(file: str, audio: mx.array, sampling_rate: int): + """ + Save audio to a wave (.wav) file. + """ + from scipy.io.wavfile import write + + audio = (audio * 32767).astype(mx.int16) + write(file, sampling_rate, np.array(audio)) + + +def load_audio(file: str, sampling_rate: int, channels: int): + """ + Read audio into an mx.array, resampling if necessary. + + Args: + file (str): The audio file to open. + sampling_rate (int): The sample rate to resample the audio at if needed. + channels (int): The number of audio channels. + + Returns: + An mx.array containing the audio waveform in float32. 
+ """ + from subprocess import CalledProcessError, run + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # fmt: off + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", "0", + "-i", file, + "-f", "s16le", + "-ac", str(channels), + "-acodec", "pcm_s16le", + "-ar", str(sampling_rate), + "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + out = mx.array(np.frombuffer(out, np.int16)) + return out.reshape(-1, channels).astype(mx.float32) / 32767.0 + + +def preprocess_audio( + raw_audio: Union[mx.array, List[mx.array]], + sampling_rate: int = 24000, + chunk_length: Optional[int] = None, + chunk_stride: Optional[int] = None, +): + r""" + Prepare inputs for the EnCodec model. + + Args: + raw_audio (mx.array or List[mx.array]): The sequence or batch of + sequences to be processed. + sampling_rate (int): The sampling rate at which the audio waveform + should be digitalized. + chunk_length (int, optional): The model's chunk length. + chunk_stride (int, optional): The model's chunk stride. + """ + if not isinstance(raw_audio, list): + raw_audio = [raw_audio] + + raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio] + + max_length = max(array.shape[0] for array in raw_audio) + if chunk_length is not None: + max_length += chunk_length - (max_length % chunk_stride) + + inputs = [] + masks = [] + for x in raw_audio: + length = x.shape[0] + mask = mx.ones((length,), dtype=mx.bool_) + difference = max_length - length + if difference > 0: + mask = mx.pad(mask, (0, difference)) + x = mx.pad(x, ((0, difference), (0, 0))) + inputs.append(x) + masks.append(mask) + return mx.stack(inputs), mx.stack(masks) + + +def load(path_or_repo): + """ + Load the model and audio preprocessor. 
+ """ + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "*.safetensors", "*.model"], + ) + ) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + + model = encodec.EncodecModel(config) + model.load_weights(str(path / "model.safetensors")) + processor = functools.partial( + preprocess_audio, + sampling_rate=config.sampling_rate, + chunk_length=model.chunk_length, + chunk_stride=model.chunk_stride, + ) + mx.eval(model) + return model, processor From e776c970f708e7a64e192d92cd90f100105a4fd6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 25 Sep 2024 23:19:41 +0900 Subject: [PATCH 13/23] Fix llava model when using text-only prompt (#998) --- llava/llava.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llava/llava.py b/llava/llava.py index 06e56059..9e6b7511 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -68,11 +68,10 @@ class LlavaModel(nn.Module): input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, ): - if pixel_values is None: - return self.language_model(input_ids) - # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) + if pixel_values is None: + return inputs_embeds # Get the ouptut hidden states from the vision model *_, hidden_states = self.vision_tower( From 76710f61af401457e505b851ff729d93489e07b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Sat, 28 Sep 2024 16:02:53 +0200 Subject: [PATCH 14/23] Adding support for mamba (#940) * initial commit * initial commit * Adding first lines * adding x, and dt projection layers * adding the clamping mechanism * First succesful inference * last commit for today - added custom geenrate function and it works as expected, will try training and then with loading a model from the hub * clean up * 
save up * almost * update * update * fixed cache handeling * fixed loading * added seperate generat_step method in the model and also in the utils to automaticaly use the generate step mthod in the model class * quick update * still not working * save * still not working * initial commit * utils.py logits = logits[:, -1, :] TypeError: tuple indices must be integers or slices, not tuple * update * update * Fixing the Batching Depfwise Comnvolution and multi token input * fixing generate and logits outputs * Done! * Fixing the cache handling, generating works now trying training * update ACKNOWLEDGEMENTS * removing the model_type if stuff in the _step loop in generate_step and adding MambaCache in base.py for training easier generations and removing mamba in tuner/utils. * quick clean up * update trainer/utils for right initialisation of the layers for LoRA, but not working. * clean up * Forther update to trainer/utils for correct layer selection. Successfull training * removing extra mamba-infer.py file * clean up, reformating will come later * reformat and big clean up, final commit * some speedups and cleanups * fix test * nits * nits --------- Co-authored-by: Awni Hannun --- ACKNOWLEDGMENTS.md | 1 + llms/mlx_lm/models/mamba.py | 231 ++++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 10 +- llms/tests/test_models.py | 29 +++-- 4 files changed, 263 insertions(+), 8 deletions(-) create mode 100644 llms/mlx_lm/models/mamba.py diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 2b98bc95..2037a076 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. +- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`. 
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py new file mode 100644 index 00000000..26408426 --- /dev/null +++ b/llms/mlx_lm/models/mamba.py @@ -0,0 +1,231 @@ +# Copyright © 2024 Apple Inc. + +import math +from dataclasses import dataclass + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + intermediate_size: int + state_size: int + num_hidden_layers: int + conv_kernel: int + use_bias: bool + use_conv_bias: bool + time_step_rank: int + tie_word_embeddings: bool = True + + def __post_init__(self): + if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): + self.hidden_size = self.d_model + if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"): + self.intermediate_size = self.d_inner + if not hasattr(self, "state_size") and hasattr(self, "d_state"): + self.state_size = self.d_state + if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"): + self.num_hidden_layers = self.n_layer + if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"): + self.num_hidden_layers = self.n_layers + if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"): + self.conv_kernel = self.d_conv + if not hasattr(self, "use_bias") and hasattr(self, "bias"): + self.use_bias = self.bias + if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"): + self.use_conv_bias = self.conv_bias + + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaCache: + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias=True, padding=0): + super().__init__() + self.channels = 
channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = mx.random.normal((self.channels, kernel_size, 1)) + self.bias = mx.zeros((channels,)) if bias else None + + def __call__(self, x, cache=None): + B, L, C = x.shape + groups, K, _ = self.weight.shape + + if cache is not None: + x = mx.concatenate([cache, x], axis=1) + else: + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + + y = mx.conv_general(x, self.weight, groups=groups) + + if self.bias is not None: + y = y + self.bias + + return y, x[:, -K + 1 :, :] + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear( + self.hidden_size, self.intermediate_size * 2, bias=args.use_bias + ) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size - 1, + ) + + self.x_proj = nn.Linear( + self.intermediate_size, + self.time_step_rank + 2 * self.ssm_state_size, + bias=False, + ) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + A = mx.repeat( + mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]), + repeats=self.intermediate_size, + axis=0, + ) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=args.use_bias + ) + + def ssm_step(self, x, state=None): + A = -mx.exp(self.A_log) + D = self.D + deltaBC = self.x_proj(x) + delta, B, C = mx.split( + deltaBC, + indices_or_sections=[ + self.time_step_rank, + self.time_step_rank + self.ssm_state_size, + ], + axis=-1, + ) + delta = 
nn.softplus(self.dt_proj(delta)) + new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) + if state is not None: + new_state += state * mx.exp(mx.expand_dims(delta, -1) * A) + y = (new_state @ mx.expand_dims(C, -1)).squeeze(2) + y = y + D * x + return y, new_state + + def __call__(self, x, cache): + B, T, D = x.shape + if cache is None: + cache = [None, None] + + outputs = [] + for t in range(T): + xt = x[:, t, :] + xz = self.in_proj(xt) + x_t, z_t = xz.split(indices_or_sections=2, axis=1) + conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) + x_t = conv_out.squeeze(1) + x_t = nn.silu(x_t) + y_t, cache[1] = self.ssm_step(x_t, cache[1]) + z_t = nn.silu(z_t) + output_t = y_t * z_t + output_t = self.out_proj(output_t) + outputs.append(output_t) + output = mx.stack(outputs, axis=1) + return output + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + return self.mixer(self.norm(x), cache) + x + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + x = layer(x, c) + return self.norm_f(x) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if 
self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + + def make_cache(self, batch_size: int = 1): + return [MambaCache() for _ in range(len(self.layers))] + + @property + def layers(self): + return self.backbone.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 71fbfaab..ab9d37aa 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -52,7 +52,6 @@ def linear_to_lora_layers( use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False`` """ - num_layers = len(model.layers) if num_lora_layers < 0: @@ -140,6 +139,15 @@ def linear_to_lora_layers( "self_attn.kv_b_proj", ] ) + elif model.model_type == "mamba": + keys = set( + [ + "mixer.in_proj", + "mixer.x_proj", + "mixer.dt_proj", + "mixer.out_proj", + ] + ) else: raise ValueError(f"Lora does not support {model.model_type}") diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index fcf1dc33..cd7e7fd0 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -5,6 +5,7 @@ import unittest import mlx.core as mx from mlx.utils import tree_map from mlx_lm.models.base import KVCache, RotatingKVCache +from mlx_lm.utils import make_kv_caches class TestModels(unittest.TestCase): @@ -100,13 +101,7 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - cache = [KVCache(model.head_dim, n) for n in kv_heads] - + cache = make_kv_caches(model) outputs = model(inputs, cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -397,6 +392,26 @@ class TestModels(unittest.TestCase): 
model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_mamba(self): + from mlx_lm.models import mamba + + args = mamba.ModelArgs( + model_type="mamba", + vocab_size=10000, + use_bias=False, + use_conv_bias=True, + conv_kernel=4, + hidden_size=768, + num_hidden_layers=24, + state_size=16, + intermediate_size=1536, + time_step_rank=48, + ) + model = mamba.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gpt2(self): from mlx_lm.models import gpt2 From d812516d3d55c25e55238aee1e083fac5378a07f Mon Sep 17 00:00:00 2001 From: jamesm131 <20141156+jamesm131@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:21:11 +1000 Subject: [PATCH 15/23] Add /v1/models endpoint to mlx_lm.server (#984) * Add 'models' endpoint to server * Add test for new 'models' server endpoint * Check hf_cache for mlx models * update tests to check hf_cache for models * simplify test * doc --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/SERVER.md | 14 +++++++++++++ llms/mlx_lm/server.py | 41 +++++++++++++++++++++++++++++++++++++++ llms/tests/test_server.py | 15 ++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 9c42d410..55be1c9c 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \ - `adapters`: (Optional) A string path to low-rank adapters. The path must be rlative to the directory the server was started in. + +### List Models + +Use the `v1/models` endpoint to list available models: + +```shell +curl localhost:8080/v1/models -H "Content-Type: application/json" +``` + +This will return a list of locally available models where each model in the +list contains the following fields: + +- `"id"`: The Hugging Face repo id. +- `"created"`: A timestamp representing the model creation time. 
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 79ac1836..f2d8b86a 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union import mlx.core as mx +from huggingface_hub import scan_cache_dir from .utils import generate_step, load @@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler): prompt = self.tokenizer.encode(prompt_text) return mx.array(prompt) + def do_GET(self): + """ + Respond to a GET request from a client. + """ + if self.path == "/v1/models": + self.handle_models_request() + else: + self._set_completion_headers(404) + self.end_headers() + self.wfile.write(b"Not Found") + + def handle_models_request(self): + """ + Handle a GET request for the /v1/models endpoint. + """ + self._set_completion_headers(200) + self.end_headers() + + # Scan the cache directory for downloaded mlx models + hf_cache_info = scan_cache_dir() + downloaded_models = [ + repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id + ] + + # Create a list of available models + models = [ + { + "id": repo.repo_id, + "object": "model", + "created": self.created, + } + for repo in downloaded_models + ] + + response = {"object": "list", "data": models} + + response_json = json.dumps(response).encode() + self.wfile.write(response_json) + self.wfile.flush() + def run( host: str, diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index baea664a..cbcccfbe 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -1,5 +1,7 @@ # Copyright © 2024 Apple Inc. 
+ import http +import json import threading import unittest @@ -77,6 +79,19 @@ class TestServer(unittest.TestCase): self.assertIn("id", response_body) self.assertIn("choices", response_body) + def test_handle_models(self): + url = f"http://localhost:{self.port}/v1/models" + response = requests.get(url) + self.assertEqual(response.status_code, 200) + response_body = json.loads(response.text) + self.assertEqual(response_body["object"], "list") + self.assertIsInstance(response_body["data"], list) + self.assertGreater(len(response_body["data"]), 0) + model = response_body["data"][0] + self.assertIn("id", model) + self.assertEqual(model["object"], "model") + self.assertIn("created", model) + def test_sequence_overlap(self): from mlx_lm.server import sequence_overlap From ace2bb58900dd382bbf3e6d9c58cbc3beee0261c Mon Sep 17 00:00:00 2001 From: nathan <97126670+nathanrchn@users.noreply.github.com> Date: Sat, 28 Sep 2024 19:08:49 +0200 Subject: [PATCH 16/23] Add logits_processor option to generate_step function (#983) * Add logits_processor option for the generation as in huggingface transformers library * concatenation correction * Rename the tokens variable for clarity * remove the logit_bias argument from generate_step method * fix the variable name * nits + test * test * add back logit bias + test --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/utils.py | 25 +++++++++++++---- llms/tests/test_generate.py | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) create mode 100644 llms/tests/test_generate.py diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5621609d..16271c3e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -154,10 +154,11 @@ def generate_step( top_p: float = 1.0, min_p: float = 0.0, min_tokens_to_keep: int = 1, - logit_bias: Optional[Dict[int, float]] = None, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, 
+ logit_bias: Optional[Dict[int, float]] = None, + logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -177,10 +178,13 @@ def generate_step( probability) that a token probability must have to be considered. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered by min_p sampling. - logit_bias (dictionary, optional): Additive logit bias. prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. + logit_bias (dictionary, optional): Additive logit bias. + logits_processor (Callable[[mx.array, mx.array], mx.array], optional): + A function that takes tokens and logits and returns the processed + logits. Default: ``None``. Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -188,10 +192,6 @@ def generate_step( """ def sample(logits: mx.array) -> Tuple[mx.array, float]: - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - logits[:, indices] += values logprobs = logits - mx.logsumexp(logits) if temp == 0: @@ -214,6 +214,7 @@ def generate_step( ) y = prompt + tokens = None # Create the KV cache for generation cache = make_kv_caches(model, max_kv_size) @@ -233,11 +234,23 @@ def generate_step( if repetition_context_size: repetition_context = repetition_context[-repetition_context_size:] + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + def _step(y): nonlocal repetition_context logits = model(y[None], cache=cache) logits = logits[:, -1, :] + if logits_processor: + nonlocal tokens + tokens = mx.concat([tokens, y]) if tokens is not None else y + logits = logits_processor(tokens, logits) + + if logit_bias: + 
logits[:, indices] += values + if repetition_penalty: logits = apply_repetition_penalty( logits, repetition_context, repetition_penalty diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py new file mode 100644 index 00000000..bc969844 --- /dev/null +++ b/llms/tests/test_generate.py @@ -0,0 +1,55 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +from mlx_lm.utils import generate, load + + +class TestGenerate(unittest.TestCase): + + @classmethod + def setUpClass(cls): + HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + cls.model, cls.tokenizer = load(HF_MODEL_PATH) + + def test_generate(self): + # Simple test that generation runs + text = generate( + self.model, self.tokenizer, "hello", max_tokens=5, verbose=False + ) + + def test_generate_with_logit_bias(self): + logit_bias = {0: 2000.0, 1: -20.0} + text = generate( + self.model, + self.tokenizer, + "hello", + max_tokens=5, + verbose=False, + logit_bias=logit_bias, + ) + self.assertEqual(text, "!!!!!") + + def test_generate_with_processor(self): + init_toks = self.tokenizer.encode("hello") + + all_toks = None + + def logits_processor(toks, logits): + nonlocal all_toks + all_toks = toks + return logits + + generate( + self.model, + self.tokenizer, + "hello", + max_tokens=5, + verbose=False, + logits_processor=logits_processor, + ) + self.assertEqual(len(all_toks), len(init_toks) + 5) + + +if __name__ == "__main__": + unittest.main() From 7ec2021bb99de16f89c59949439f428fc968a0d7 Mon Sep 17 00:00:00 2001 From: madroid Date: Sun, 29 Sep 2024 01:41:36 +0800 Subject: [PATCH 17/23] LoRA: support tools(function calling) format datasets (#995) * LoRA: support fine-tuning tools datasets * LoRA: Split small function * LoRA: add tools format to lora docs * LoRA: pre-commit fix * Revert "LoRA: pre-commit fix" This reverts commit b94b7e0fe7c6adfb642e1392710c027096d91d49. * Revert "LoRA: Split small function" This reverts commit 3f6a5f19fd8ba24bf6933c3a9bdcc66c8b29825f. 
* LoRA: remove ToolsDataset In a JSONL file, not all data is required to include the tools value. * nit in readme * nit in readme * nit in readme --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 68 +++++++++++++++++++++++++++++++---- llms/mlx_lm/tuner/datasets.py | 5 ++- llms/mlx_lm/tuner/trainer.py | 4 +-- 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2d9a2553..8aec89ec 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -160,8 +160,8 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data loader expects a `test.jsonl` in the data directory. -Currently, `*.jsonl` files support three data formats: `chat`, -`completions`, and `text`. Here are three examples of these formats: +Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text` +data formats. Here are examples of these formats: `chat`: @@ -169,6 +169,58 @@ Currently, `*.jsonl` files support three data formats: `chat`, {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]} ``` +`tools`: + +```jsonl +{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]} +``` + +
+View the expanded single data tool format + +```jsonl +{ + "messages": [ + { "role": "user", "content": "What is the weather in San Francisco?" }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_id", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}" + } + } + ] + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and country, eg. San Francisco, USA" + }, + "format": { "type": "string", "enum": ["celsius", "fahrenheit"] } + }, + "required": ["location", "format"] + } + } + } + ] +} +``` + +
+ `completions`: ```jsonl @@ -215,11 +267,13 @@ hf_dataset: - Arguments specified in `config` will be passed as keyword arguments to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). -In general, for the `chat` and `completions` formats, Hugging Face [chat -templates](https://huggingface.co/blog/chat-templates) are used. This applies -the model's chat template by default. If the model does not have a chat -template, then Hugging Face will use a default. For example, the final text in -the `chat` example above with Hugging Face's default template becomes: +In general, for the `chat`, `tools` and `completions` formats, Hugging Face +[chat +templates](https://huggingface.co/docs/transformers/main/en/chat_templating) +are used. This applies the model's chat template by default. If the model does +not have a chat template, then Hugging Face will use a default. For example, +the final text in the `chat` example above with Hugging Face's default template +becomes: ```text <|im_start|>system diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 3d99894c..2b8abf43 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -36,7 +36,10 @@ class ChatDataset(Dataset): def __getitem__(self, idx: int): messages = self._data[idx]["messages"] text = self._tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, + tools=self._data[idx].get("tools", None), + tokenize=False, + add_generation_prompt=True, ) return text diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 24fcc5c6..b15801a5 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -93,9 +93,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) # Encode batch batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] for b in batch: - if b[-1] == tokenizer.eos_token_id: 
- print("[WARNING] Example already has an EOS token appended") - else: + if b[-1] != tokenizer.eos_token_id: b.append(tokenizer.eos_token_id) lengths = [len(x) for x in batch] From 50e5ca81a8c06f4c49cec48795330209a885d2c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Mon, 30 Sep 2024 02:12:47 +0200 Subject: [PATCH 18/23] Adding full finetuning (#903) * Adding full model weights finetuning * Updating the LORA.md and ACKNOWLEDGMENTS.md files. * removing --use-dora and --fulll-training and adding --fine-tune-type * some clean up * reformating and fixing dora training * updated CONFIG_DEFAULTS * update config example * update in the config example fie * Update LORA.md * merge and commit * adding argument for dora linear layer * clean up * clean up in the example yaml file * fix * final fix before sending * small addition to re md file * fix for loading the fully trained model by saving all the files and configs correctly * clean up * removing the unnesesairy files * changing lora layers back to 16 * removed max file size * nits * resolve merge * some consistency changes --------- Co-authored-by: Awni Hannun --- ACKNOWLEDGMENTS.md | 2 +- llms/README.md | 2 +- llms/mlx_lm/LORA.md | 16 ++++++---- llms/mlx_lm/examples/lora_config.yaml | 7 ++-- llms/mlx_lm/fuse.py | 15 ++++----- llms/mlx_lm/lora.py | 46 +++++++++++++++++---------- llms/mlx_lm/tuner/trainer.py | 22 ++++++------- llms/mlx_lm/tuner/utils.py | 35 ++++++++++---------- llms/mlx_lm/utils.py | 4 +-- 9 files changed, 79 insertions(+), 70 deletions(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 2037a076..41557c29 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. 
-- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`. +- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`. \ No newline at end of file diff --git a/llms/README.md b/llms/README.md index b8e1914d..75677865 100644 --- a/llms/README.md +++ b/llms/README.md @@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm The `mlx-lm` package also has: -- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md) +- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md) - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md) - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 8aec89ec..80c25b4b 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -57,6 +57,9 @@ mlx_lm.lora \ --iters 600 ``` +To fine-tune the full model weights, add the `--fine-tune-type full` flag. +Currently supported fine-tuning types are `lora` (default), `dora`, and `full`. + The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl` when using `--train` and a path to a `test.jsonl` when using `--test`. For more details on the data format see the section on [Data](#Data). @@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`. If `--model` points to a quantized model, then the training will use QLoRA, otherwise it will use regular LoRA. -By default, the adapter config and weights are saved in `adapters/`. You can -specify the output location with `--adapter-path`. +By default, the adapter config and learned weights are saved in `adapters/`. +You can specify the output location with `--adapter-path`. You can resume fine-tuning with an existing adapter with `--resume-adapter-file `. 
@@ -118,7 +121,7 @@ mlx_lm.fuse --model ``` This will by default load the adapters from `adapters/`, and save the fused -model in the path `lora_fused_model/`. All of these are configurable. +model in the path `fused_model/`. All of these are configurable. To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments to `mlx_lm.fuse`. The latter is the repo name of the original model, which is @@ -141,7 +144,7 @@ mlx_lm.fuse \ --export-gguf ``` -This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You +This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You can specify the file name with `--gguf-path`. ## Data @@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so: setting this to `2` or `1` will reduce memory consumption. This may slow things down a little, but will also reduce the memory use. -3. Reduce the number of layers to fine-tune with `--lora-layers`. The default +3. Reduce the number of layers to fine-tune with `--num-layers`. The default is `16`, so you can try `8` or `4`. This reduces the amount of memory needed for back propagation. It may also reduce the quality of the fine-tuned model if you are fine-tuning with a lot of data. @@ -323,7 +326,7 @@ mlx_lm.lora \ --model mistralai/Mistral-7B-v0.1 \ --train \ --batch-size 1 \ - --lora-layers 4 \ + --num-layers 4 \ --data wikisql ``` @@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example data set. [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA. + [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 073a5b6f..4ec9a23c 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -1,8 +1,12 @@ # The path to the local model directory or Hugging Face repo. 
model: "mlx_model" + # Whether or not to train (boolean) train: true +# The fine-tuning method: "lora", "dora", or "full". +fine_tune_type: lora + # Directory with {train, valid, test}.jsonl files data: "/path/to/training/data" @@ -51,9 +55,6 @@ max_seq_length: 2048 # Use gradient checkpointing to reduce memory use. grad_checkpoint: false -# Use DoRA instead of LoRA. -use_dora: false - # LoRA parameters can only be specified in a config file lora_parameters: # The layer keys to apply LoRA to. diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py index 16457036..b0c46a74 100644 --- a/llms/mlx_lm/fuse.py +++ b/llms/mlx_lm/fuse.py @@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten from .gguf import convert_to_gguf from .tuner.dora import DoRAEmbedding, DoRALinear from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear -from .tuner.utils import apply_lora_layers, dequantize +from .tuner.utils import dequantize, load_adapters from .utils import ( fetch_from_hub, get_model_path, @@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace: ) parser.add_argument( "--save-path", - default="lora_fused_model", + default="fused_model", help="The path to save the fused model.", ) parser.add_argument( @@ -77,17 +77,14 @@ def main() -> None: model, config, tokenizer = fetch_from_hub(model_path) model.freeze() - model = apply_lora_layers(model, args.adapter_path) + model = load_adapters(model, args.adapter_path) fused_linears = [ - (n, m.fuse()) - for n, m in model.named_modules() - if isinstance( - m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding) - ) + (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse") ] - model.update_modules(tree_unflatten(fused_linears)) + if fused_linears: + model.update_modules(tree_unflatten(fused_linears)) if args.de_quantize: print("De-quantizing model") diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 580e3d3c..69232774 100644 --- a/llms/mlx_lm/lora.py +++ 
b/llms/mlx_lm/lora.py @@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper from .tuner.datasets import load_dataset from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train from .tuner.utils import ( - apply_lora_layers, build_schedule, linear_to_lora_layers, + load_adapters, print_trainable_parameters, ) from .utils import load, save_config @@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver( CONFIG_DEFAULTS = { "model": "mlx_model", "train": False, + "fine_tune_type": "lora", "data": "data/", "seed": 0, - "lora_layers": 16, + "num_layers": 16, "batch_size": 4, "iters": 1000, "val_batches": 25, @@ -58,7 +59,6 @@ CONFIG_DEFAULTS = { "max_seq_length": 2048, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, - "use_dora": False, } @@ -82,7 +82,14 @@ def build_parser(): help="Directory with {train, valid, test}.jsonl files", ) parser.add_argument( - "--lora-layers", + "--fine-tune-type", + type=str, + choices=["lora", "dora", "full"], + default="lora", + help="Type of fine-tuning to perform: lora, dora, or full.", + ) + parser.add_argument( + "--num-layers", type=int, help="Number of layers to fine-tune. Default is 16, use -1 for all.", ) @@ -107,12 +114,12 @@ def build_parser(): parser.add_argument( "--resume-adapter-file", type=str, - help="Load path to resume training with the given adapters.", + help="Load path to resume training from the given fine-tuned weights.", ) parser.add_argument( "--adapter-path", type=str, - help="Save/load path for the adapters.", + help="Save/load path for the fine-tuned weights.", ) parser.add_argument( "--save-every", @@ -148,9 +155,6 @@ def build_parser(): default=None, ) parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") - parser.add_argument( - "--use-dora", action="store_true", default=None, help="Use DoRA to finetune." 
- ) return parser @@ -162,21 +166,31 @@ def train_model( valid_set, training_callback: TrainingCallback = None, ): - # Freeze all layers model.freeze() + if args.fine_tune_type == "full": + for l in model.layers[-min(args.num_layers, 0) :]: + l.unfreeze() + elif args.fine_tune_type in ["lora", "dora"]: + # Convert linear layers to lora/dora layers and unfreeze in the process + linear_to_lora_layers( + model, + args.num_layers, + args.lora_parameters, + use_dora=(args.fine_tune_type == "dora"), + ) + else: + raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}") - # Convert linear layers to lora layers and unfreeze in the process - linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora) - - # Resume training the given adapters. + # Resume from weights if provided if args.resume_adapter_file is not None: - print(f"Loading pretrained adapters from {args.resume_adapter_file}") + print(f"Loading fine-tuned weights from {args.resume_adapter_file}") model.load_weights(args.resume_adapter_file, strict=False) print_trainable_parameters(model) adapter_path = Path(args.adapter_path) adapter_path.mkdir(parents=True, exist_ok=True) + adapter_file = adapter_path / "adapters.safetensors" save_config(vars(args), adapter_path / "adapter_config.json") @@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None): if args.test and not args.train: # Allow testing without LoRA layers by providing empty path if args.adapter_path != "": - apply_lora_layers(model, args.adapter_path) + load_adapters(model, args.adapter_path) elif args.train: print("Training") diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index b15801a5..1d934a72 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -1,5 +1,7 @@ # Copyright © 2024 Apple Inc. 
+import glob +import shutil import time from dataclasses import dataclass, field from pathlib import Path @@ -285,24 +287,18 @@ def train( # Save adapter weights if it % args.steps_per_save == 0: - save_adapter(model, args.adapter_file) + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) checkpoint = ( Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" ) - save_adapter(model, checkpoint) + mx.save_safetensors(str(checkpoint), adapter_weights) print( f"Iter {it}: Saved adapter weights to " f"{args.adapter_file} and {checkpoint}." ) - # save final adapter weights - save_adapter(model, args.adapter_file) - print(f"Saved final adapter weights to {args.adapter_file}.") - - -def save_adapter( - model: nn.Module, - adapter_file: Union[str, Path], -): - flattened_tree = tree_flatten(model.trainable_parameters()) - mx.save_safetensors(str(adapter_file), dict(flattened_tree)) + # Save final weights + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + print(f"Saved final weights to {args.adapter_file}.") diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index ab9d37aa..7c78ee91 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict): def linear_to_lora_layers( model: nn.Module, - num_lora_layers: int, + num_layers: int, config: Dict, use_dora: bool = False, ): @@ -45,22 +45,17 @@ def linear_to_lora_layers( Args: model (nn.Module): The neural network model. - num_lora_layers (int): The number of blocks to convert to lora layers + num_layers (int): The number of blocks to convert to lora layers starting from the last layer. config (dict): More configuration parameters for LoRA, including the rank, scale, and optional layer keys. use_dora (bool): If True, uses DoRA instead of LoRA. 
Default: ``False`` """ - num_layers = len(model.layers) - - if num_lora_layers < 0: - num_lora_layers = num_layers - - if num_lora_layers > num_layers: + if num_layers > len(model.layers): raise ValueError( - f"Requested {num_lora_layers} LoRA layers " - f"but the model only has {num_layers} layers." + f"Requested {num_layers} LoRA layers " + f"but the model only has {len(model.layers)} layers." ) def to_lora(layer): @@ -151,7 +146,7 @@ def linear_to_lora_layers( else: raise ValueError(f"Lora does not support {model.model_type}") - for l in model.layers[num_layers - num_lora_layers :]: + for l in model.layers[-min(num_layers, 0) :]: lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] if lora_layers: l.update_modules(tree_unflatten(lora_layers)) @@ -161,9 +156,9 @@ def linear_to_lora_layers( model.update_modules(tree_unflatten(lora_modules)) -def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: +def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module: """ - Apply LoRA layers to the model. + Load any fine-tuned adapters / layers. Args: model (nn.Module): The neural network model. 
@@ -177,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") with open(adapter_path / "adapter_config.json", "r") as fid: config = types.SimpleNamespace(**json.load(fid)) - linear_to_lora_layers( - model, - config.lora_layers, - config.lora_parameters, - getattr(config, "use_dora", False), - ) + fine_tune_type = getattr(config, "fine_tune_type", "lora") + if fine_tune_type != "full": + linear_to_lora_layers( + model, + config.num_layers, + config.lora_parameters, + use_dora=(fine_tune_type == "dora"), + ) model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) return model diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 16271c3e..9411138d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer from .models.base import KVCache, RotatingKVCache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer -from .tuner.utils import apply_lora_layers from .tuner.utils import dequantize as dequantize_model +from .tuner.utils import load_adapters # Constants MODEL_REMAPPING = { @@ -515,7 +515,7 @@ def load( model = load_model(model_path, lazy, model_config) if adapter_path is not None: - model = apply_lora_layers(model, adapter_path) + model = load_adapters(model, adapter_path) model.eval() tokenizer = load_tokenizer(model_path, tokenizer_config) From aa1c8abdc67be20a1efcacb0e3102ac732a87967 Mon Sep 17 00:00:00 2001 From: madroid Date: Mon, 30 Sep 2024 22:36:21 +0800 Subject: [PATCH 19/23] LoRA: Support HuggingFace dataset via data parameter (#996) * LoRA: support huggingface dataset via `data` argument * LoRA: Extract the load_custom_hf_dataset function * LoRA: split small functions * fix spelling errors * handle load hf dataset error * fix pre-commit lint * update data argument 
help * nits and doc --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 8 ++- llms/mlx_lm/lora.py | 5 +- llms/mlx_lm/tuner/datasets.py | 131 +++++++++++++++++++++------------- 3 files changed, 93 insertions(+), 51 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 80c25b4b..2d0dcf60 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -251,7 +251,13 @@ To use Hugging Face datasets, first install the `datasets` package: pip install datasets ``` -Specify the Hugging Face dataset arguments in a YAML config. For example: +If the Hugging Face dataset is already in a supported format, you can specify +it on the command line. For example, pass `--data mlx-community/wikisql` to +train on the pre-formatted WikiwSQL data. + +Otherwise, provide a mapping of keys in the dataset to the features MLX LM +expects. Use a YAML config to specify the Hugging Face dataset arguments. For +example: ``` hf_dataset: diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 69232774..c96e75a7 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -79,7 +79,10 @@ def build_parser(): parser.add_argument( "--data", type=str, - help="Directory with {train, valid, test}.jsonl files", + help=( + "Directory with {train, valid, test}.jsonl files or the name " + "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')" + ), ) parser.add_argument( "--fine-tune-type", diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 2b8abf43..20b32eff 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -76,17 +76,14 @@ class CompletionsDataset(Dataset): return text -def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None): - # Return empty dataset for non-existent paths - if not path.exists(): - return [] - with open(path, "r") as fid: - data = [json.loads(l) for l in fid] - if "messages" in data[0]: +def create_dataset(data, tokenizer: PreTrainedTokenizer = None): + sample = data[0] + + 
if "messages" in sample: return ChatDataset(data, tokenizer) - elif "prompt" in data[0] and "completion" in data[0]: + elif "prompt" in sample and "completion" in sample: return CompletionsDataset(data, tokenizer) - elif "text" in data[0]: + elif "text" in sample: return Dataset(data) else: raise ValueError( @@ -95,54 +92,90 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None): ) -def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", None) is not None: - import datasets +def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): + def load_subset(path): + if not path.exists(): + return [] + with open(path, "r") as fid: + data = [json.loads(l) for l in fid] + return create_dataset(data, tokenizer) - hf_args = args.hf_dataset - dataset_name = hf_args["name"] - print(f"Loading Hugging Face dataset {dataset_name}.") - text_feature = hf_args.get("text_feature") - prompt_feature = hf_args.get("prompt_feature") - completion_feature = hf_args.get("completion_feature") + names = ("train", "valid", "test") + train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] + return train, valid, test - def create_hf_dataset(split: str = None): - ds = datasets.load_dataset( - dataset_name, - split=split, - **hf_args.get("config", {}), - ) - if prompt_feature and completion_feature: - return CompletionsDataset( - ds, tokenizer, prompt_feature, completion_feature - ) - elif text_feature: - return Dataset(train_ds, text_key=text_feature) - else: - raise ValueError( - "Specify either a prompt and completion feature or a text " - "feature for the Hugging Face dataset." 
- ) - if args.train: - train_split = hf_args.get("train_split", "train[:80%]") - valid_split = hf_args.get("valid_split", "train[-10%:]") - train = create_hf_dataset(split=train_split) - valid = create_hf_dataset(split=valid_split) - else: - train, valid = [], [] - if args.test: - test = create_hf_dataset(split=hf_args.get("test_split")) - else: - test = [] +def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): + from datasets import exceptions, load_dataset + + try: + dataset = load_dataset(data_id) - else: names = ("train", "valid", "test") - data_path = Path(args.data) train, valid, test = [ - create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names + create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] + for n in names ] + + except exceptions.DatasetNotFoundError: + raise ValueError(f"Not found Hugging Face dataset: {data_id} .") + + return train, valid, test + + +def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): + import datasets + + hf_args = args.hf_dataset + dataset_name = hf_args["name"] + print(f"Loading Hugging Face dataset {dataset_name}.") + text_feature = hf_args.get("text_feature") + prompt_feature = hf_args.get("prompt_feature") + completion_feature = hf_args.get("completion_feature") + + def create_hf_dataset(split: str = None): + ds = datasets.load_dataset( + dataset_name, + split=split, + **hf_args.get("config", {}), + ) + if prompt_feature and completion_feature: + return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) + elif text_feature: + return Dataset(train_ds, text_key=text_feature) + else: + raise ValueError( + "Specify either a prompt and completion feature or a text " + "feature for the Hugging Face dataset." 
+ ) + + if args.train: + train_split = hf_args.get("train_split", "train[:80%]") + valid_split = hf_args.get("valid_split", "train[-10%:]") + train = create_hf_dataset(split=train_split) + valid = create_hf_dataset(split=valid_split) + else: + train, valid = [], [] + if args.test: + test = create_hf_dataset(split=hf_args.get("test_split")) + else: + test = [] + + return train, valid, test + + +def load_dataset(args, tokenizer: PreTrainedTokenizer): + if getattr(args, "hf_dataset", None) is not None: + train, valid, test = load_custom_hf_dataset(args, tokenizer) + else: + data_path = Path(args.data) + if data_path.exists(): + train, valid, test = load_local_dataset(data_path, tokenizer) + else: + print(f"Loading Hugging Face dataset {args.data}.") + train, valid, test = load_hf_dataset(args.data, tokenizer) + if args.train and len(train) == 0: raise ValueError( "Training set not found or empty. Must provide training set for fine-tuning." From 418d9a5511fbefced1229cfed1d311a771aa5db5 Mon Sep 17 00:00:00 2001 From: Zai Thottakath Date: Mon, 30 Sep 2024 10:01:11 -0500 Subject: [PATCH 20/23] Feature: QDoRA (#891) * feat: QDoRA with tests and a small bug fix for recalculation of self.m * some simplifications and fixes --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/tuner/dora.py | 51 ++++++++++--- llms/tests/test_finetune.py | 143 +++++++++++++++++++++++++++++++++++- 2 files changed, 183 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py index bd2dfb01..aba1f6f4 100644 --- a/llms/mlx_lm/tuner/dora.py +++ b/llms/mlx_lm/tuner/dora.py @@ -14,10 +14,11 @@ class DoRALinear(nn.Module): dropout: float = 0.0, scale: float = 20.0, ): - # TODO support quantized weights in DoRALinear + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear output_dims, input_dims = linear.weight.shape if isinstance(linear, nn.QuantizedLinear): - raise ValueError("DoRALinear does not yet support 
quantization.") + input_dims *= 32 // linear.bits dora_lin = DoRALinear( input_dims=input_dims, output_dims=output_dims, @@ -31,13 +32,13 @@ class DoRALinear(nn.Module): def fuse(self, de_quantize: bool = False): linear = self.linear bias = "bias" in linear - weight = linear.weight + weight = self._dequantized_weight() - # Use the same type as the linear weight if not quantized + # Use the same type as the linear weight dtype = weight.dtype output_dims, input_dims = weight.shape - fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + fused_linear = nn.Linear(input_dims, output_dims, bias=False) lora_b = (self.scale * self.lora_b.T).astype(dtype) lora_a = self.lora_a.T.astype(dtype) @@ -47,6 +48,13 @@ class DoRALinear(nn.Module): if bias: fused_linear.bias = linear.bias + + if self._is_quantized() and not de_quantize: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) return fused_linear def __init__( @@ -76,22 +84,45 @@ class DoRALinear(nn.Module): ) self.lora_b = mx.zeros(shape=(r, output_dims)) - def set_linear(self, linear: nn.Linear): + def set_linear(self, linear): + """ + Set the self.linear layer and recompute self.m. 
+ """ self.linear = linear - self.m = mx.linalg.norm(self.linear.weight, axis=1) + self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1) + + def _dequantized_weight(self): + """ + Return the weight of linear layer and dequantize it if is quantized + """ + weight = self.linear.weight + if self._is_quantized(): + weight = mx.dequantize( + weight, + self.linear.scales, + self.linear.biases, + self.linear.group_size, + self.linear.bits, + ) + return weight + + def _is_quantized(self): + return isinstance(self.linear, nn.QuantizedLinear) def __call__(self, x): # Regular LoRA (without a bias) - y = x @ self.linear.weight.T + w = self._dequantized_weight() + y = x @ w.T + z = (self.dropout(x) @ self.lora_a) @ self.lora_b out = y + (self.scale * z).astype(x.dtype) # Compute the norm of the adapted weights - adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T + adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1)) # Remove the norm and scale by the learned magnitude - out = (self.m / denom) * out + out = (self.m / denom).astype(x.dtype) * out if "bias" in self.linear: out = out + self.linear.bias diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 289b8cfb..107be092 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -11,7 +11,7 @@ import mlx.nn as nn import mlx.optimizers as opt from mlx.utils import tree_flatten from mlx_lm import lora, tuner -from mlx_lm.tuner.dora import DoRAEmbedding +from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear from mlx_lm.tuner.trainer import evaluate from mlx_lm.tuner.utils import build_schedule @@ -164,6 +164,147 @@ class TestDora(unittest.TestCase): self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight)) self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens))) + def test_llama(self): + from 
mlx_lm.models import llama + + hidden_size = 1024 + intermediate_size = 2048 + args = llama.ModelArgs( + model_type="llama", + hidden_size=hidden_size, + num_hidden_layers=4, + intermediate_size=intermediate_size, + num_attention_heads=4, + rms_norm_eps=1e-5, + vocab_size=10_000, + ) + + dora_layers = 4 + + def check_config(params): + n_keys = 2 + if "keys" in params: + n_keys = len(params["keys"]) + model = llama.Model(args) + model.freeze() + tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True) + trainable_params = sum( + v.size for _, v in tree_flatten(model.trainable_parameters()) + ) + self.assertEqual( + trainable_params, + dora_layers + * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size), + ) + + params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0} + check_config(params) + + params["rank"] = 1 + check_config(params) + + params["keys"] = ["self_attn.k_proj"] + check_config(params) + + def test_dora_m_parameter(self): + dora_lin = DoRALinear(input_dims=100, output_dims=100) + self.assertTrue( + mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1)) + ) + + # Recomputes m when changing Linear + inital_m = dora_lin.m + lin = nn.Linear(10, 10) + dora_lin.set_linear(lin) + self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1))) + + # Works with quantized weights + quantized_linear = nn.QuantizedLinear(512, 512) + dora_lin.set_linear(quantized_linear) + dequantized_weight = mx.dequantize( + quantized_linear.weight, + quantized_linear.scales, + quantized_linear.biases, + quantized_linear.group_size, + quantized_linear.bits, + ) + self.assertTrue( + mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1)) + ) + + def test_dora_from_linear(self): + in_dims = 256 + out_dims = 256 + r = 4 + + linear = nn.Linear(in_dims, out_dims) + dora_lin = DoRALinear.from_base(linear, r) + self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1))) + 
self.assertEqual(dora_lin.lora_a.shape, (in_dims, r)) + self.assertEqual(dora_lin.lora_b.shape, (r, out_dims)) + self.assertEqual(dora_lin.m.shape, (out_dims,)) + + quantized_linear = nn.QuantizedLinear(in_dims, out_dims) + dequantized_weight = mx.dequantize( + quantized_linear.weight, + quantized_linear.scales, + quantized_linear.biases, + quantized_linear.group_size, + quantized_linear.bits, + ) + dora_quant_lin = DoRALinear.from_base(quantized_linear, r) + self.assertTrue( + mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1)) + ) + self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r)) + self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims)) + self.assertEqual(dora_quant_lin.m.shape, (out_dims,)) + + def test_dora_to_linear(self): + in_dims = 256 + out_dims = 256 + r = 4 + + linear = nn.Linear(in_dims, out_dims, bias=True) + dora_lin = DoRALinear.from_base(linear, r) + to_linear = dora_lin.fuse() + self.assertTrue(mx.allclose(linear.weight, to_linear.weight)) + self.assertTrue(mx.allclose(linear.bias, to_linear.bias)) + + def dequantize_weight(quantized_linear): + return mx.dequantize( + quantized_linear.weight, + quantized_linear.scales, + quantized_linear.biases, + quantized_linear.group_size, + quantized_linear.bits, + ) + + quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True) + dora_quantized_linear = DoRALinear.from_base(quantized_linear, r) + # Dequantize + to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True) + self.assertTrue( + mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias) + ) + self.assertTrue( + mx.allclose( + dequantize_weight(quantized_linear), to_linear_from_quantized.weight + ) + ) + + def test_dora_dtype(self): + in_dims = 256 + out_dims = 256 + r = 4 + + linear = nn.Linear(in_dims, out_dims, bias=True) + linear.set_dtype(mx.float16) + dora_lin = DoRALinear.from_base(linear, r) + + x = mx.random.uniform(shape=(2, 256)).astype(mx.float16) + 
self.assertEqual(dora_lin(x).dtype, mx.float16) + class TestScheduleConfig(unittest.TestCase): def test_join(self): From 0866e23a67c3d5ab8dbac352b112f13967103942 Mon Sep 17 00:00:00 2001 From: nathan <97126670+nathanrchn@users.noreply.github.com> Date: Mon, 30 Sep 2024 17:49:03 +0200 Subject: [PATCH 21/23] repetiton_penalty and logits_bias just using logits_processors (#1004) * refactor of repetition_penalty and logits_bias to use logits_processor * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/utils.py | 66 ++++++++++++++++++------------------- llms/tests/test_generate.py | 2 +- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 9411138d..54a96457 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -101,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path -def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float): +def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): """ Apply repetition penalty to specific logits based on the given context. @@ -109,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f Args: logits (mx.array): The logits produced by the language model. - generated_tokens (any): A list of N previous tokens. + tokens (mx.array): A list of N previous tokens. penalty (float): The repetition penalty factor to be applied. Returns: logits (mx.array): Logits with repetition penalty applied to generated tokens. 
""" - if len(generated_tokens) > 0: - indices = mx.array([token for token in generated_tokens]) - selected_logits = logits[:, indices] + if len(tokens) > 0: + selected_logits = logits[:, tokens] selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty ) - logits[:, indices] = selected_logits + logits[:, tokens] = selected_logits return logits @@ -158,7 +157,7 @@ def generate_step( max_kv_size: Optional[int] = None, cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None, + logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -182,8 +181,8 @@ def generate_step( max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. logit_bias (dictionary, optional): Additive logit bias. - logits_processor (Callable[[mx.array, mx.array], mx.array], optional): - A function that takes tokens and logits and returns the processed + logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed logits. Default: ``None``. 
Yields: @@ -213,6 +212,27 @@ def generate_step( f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) + logits_processor = logits_processor or [] + + if repetition_penalty: + + def repetition_penalty_processor(tokens, logits): + return apply_repetition_penalty( + logits, tokens[-repetition_context_size:], repetition_penalty + ) + + logits_processor.append(repetition_penalty_processor) + + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): + logits[:, indices] += values + return logits + + logits_processor.append(logit_bias_processor) + y = prompt tokens = None @@ -229,40 +249,18 @@ def generate_step( c.update_and_fetch(h[0], h[1]) mx.eval([c.state for c in cache]) - repetition_context = prompt.tolist() - - if repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - def _step(y): - nonlocal repetition_context logits = model(y[None], cache=cache) logits = logits[:, -1, :] if logits_processor: nonlocal tokens tokens = mx.concat([tokens, y]) if tokens is not None else y - logits = logits_processor(tokens, logits) - if logit_bias: - logits[:, indices] += values + for processor in logits_processor: + logits = processor(tokens, logits) - if repetition_penalty: - logits = apply_repetition_penalty( - logits, repetition_context, repetition_penalty - ) - y, logprobs = sample(logits) - repetition_context.append(y.item()) - else: - y, logprobs = sample(logits) - - if repetition_context_size: - if len(repetition_context) > repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] + y, logprobs = sample(logits) return y, logprobs.squeeze(0) while y.size > prefill_step_size: diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index bc969844..68f1670b 100644 
--- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase): "hello", max_tokens=5, verbose=False, - logits_processor=logits_processor, + logits_processor=[logits_processor], ) self.assertEqual(len(all_toks), len(init_toks) + 5) From 36c1d8e8dcd42a7104ab532ef4ae003c08f41c5f Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 3 Oct 2024 03:36:07 +0800 Subject: [PATCH 22/23] Server: support function calling (#1003) --- llms/mlx_lm/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index f2d8b86a..42962b54 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -594,6 +594,7 @@ class APIHandler(BaseHTTPRequestHandler): ): prompt = self.tokenizer.apply_chat_template( body["messages"], + body.get("tools", None), tokenize=True, add_generation_prompt=True, ) From 9bc53fc2100319d59179a179efe34346372772cf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 2 Oct 2024 13:13:33 -0700 Subject: [PATCH 23/23] convert (#1006) --- whisper/convert.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/whisper/convert.py b/whisper/convert.py index da7195e0..cdd50bc5 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -35,6 +35,8 @@ _MODELS = { "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", + "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt", + "turbo": 
"https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt", } # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are @@ -52,6 +54,8 @@ _ALIGNMENT_HEADS = { "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", + "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`", + "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`", }