From 4d2ee6740285150c1f28ad01a0e4cebd7684665c Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Tue, 8 Oct 2024 15:46:04 -0700 Subject: [PATCH] change to from_pretrained --- encodec/README.md | 4 +- encodec/benchmarks/bench_mx.py | 5 +- encodec/encodec.py | 57 +++++----- encodec/example.py | 4 +- encodec/test.py | 9 +- musicgen/benchmarks/bench_mx.py | 10 +- musicgen/benchmarks/bench_pt.py | 2 + musicgen/generate.py | 6 +- musicgen/musicgen.py | 66 ++++++++++- musicgen/utils.py | 60 ---------- t5/README.md | 40 +++---- t5/t5.py | 195 +++++++++++++++++--------------- 12 files changed, 231 insertions(+), 227 deletions(-) diff --git a/encodec/README.md b/encodec/README.md index 9a01dc7b..a3b948bf 100644 --- a/encodec/README.md +++ b/encodec/README.md @@ -33,11 +33,11 @@ An example using the model: ```python import mlx.core as mx -from encodec import load +from encodec import EncodecModel from utils import load_audio, save_audio # Load the 48 KHz model and preprocessor. -model, processor = load("mlx-community/encodec-48khz-float32") +model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") # Load an audio file audio = load_audio("path/to/audio", model.sampling_rate, model.channels) diff --git a/encodec/benchmarks/bench_mx.py b/encodec/benchmarks/bench_mx.py index 2acd4b75..61ddaae8 100644 --- a/encodec/benchmarks/bench_mx.py +++ b/encodec/benchmarks/bench_mx.py @@ -3,9 +3,10 @@ import time import mlx.core as mx -from utils import load -model, processor = load("mlx-community/encodec-48khz-float32") +from encodec import EncodecModel + +model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") audio = mx.random.uniform(shape=(288000, 2)) feats, mask = processor(audio) diff --git a/encodec/encodec.py b/encodec/encodec.py index 8f98a1a3..4b85dfdd 100644 --- a/encodec/encodec.py +++ b/encodec/encodec.py @@ -673,6 +673,33 @@ class EncodecModel(nn.Module): audio_values = audio_values[:, : padding_mask.shape[1]] return 
audio_values + @classmethod + def from_pretrained(cls, path_or_repo: str): + from huggingface_hub import snapshot_download + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "*.safetensors", "*.model"], + ) + ) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + + model = EncodecModel(config) + model.load_weights(str(path / "model.safetensors")) + processor = functools.partial( + preprocess_audio, + sampling_rate=config.sampling_rate, + chunk_length=model.chunk_length, + chunk_stride=model.chunk_stride, + ) + mx.eval(model) + return model, processor + def preprocess_audio( raw_audio: Union[mx.array, List[mx.array]], @@ -712,33 +739,3 @@ def preprocess_audio( inputs.append(x) masks.append(mask) return mx.stack(inputs), mx.stack(masks) - - -def load(path_or_repo): - """ - Load the model and audo preprocessor. - """ - from huggingface_hub import snapshot_download - - path = Path(path_or_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_repo, - allow_patterns=["*.json", "*.safetensors", "*.model"], - ) - ) - - with open(path / "config.json", "r") as f: - config = SimpleNamespace(**json.load(f)) - - model = EncodecModel(config) - model.load_weights(str(path / "model.safetensors")) - processor = functools.partial( - preprocess_audio, - sampling_rate=config.sampling_rate, - chunk_length=model.chunk_length, - chunk_stride=model.chunk_stride, - ) - mx.eval(model) - return model, processor diff --git a/encodec/example.py b/encodec/example.py index e5a899f1..15ea476c 100644 --- a/encodec/example.py +++ b/encodec/example.py @@ -3,10 +3,10 @@ import mlx.core as mx from utils import load_audio, save_audio -from encodec import load +from encodec import EncodecModel # Load the 48 KHz model and preprocessor. 
-model, processor = load("mlx-community/encodec-48khz-float32") +model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") # Load an audio file audio = load_audio("/path/to/audio", model.sampling_rate, model.channels) diff --git a/encodec/test.py b/encodec/test.py index 1012f5f9..ae565c29 100644 --- a/encodec/test.py +++ b/encodec/test.py @@ -3,9 +3,10 @@ import mlx.core as mx import numpy as np import torch -from transformers import AutoProcessor, EncodecModel +from transformers import AutoProcessor +from transformers import EncodecModel as PTEncodecModel -from encodec import load, preprocess_audio +from encodec import EncodecModel, preprocess_audio def compare_processors(): @@ -30,8 +31,8 @@ def compare_processors(): def compare_models(): - pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz") - mx_model, _ = load("mlx-community/encodec-48khz-float32") + pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz") + mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") np.random.seed(0) audio_length = 190560 diff --git a/musicgen/benchmarks/bench_mx.py b/musicgen/benchmarks/bench_mx.py index 79842efd..b669597b 100644 --- a/musicgen/benchmarks/bench_mx.py +++ b/musicgen/benchmarks/bench_mx.py @@ -1,12 +1,18 @@ +# Copyright © 2024 Apple Inc. + import sys import time from pathlib import Path import mlx.core as mx -from utils import load + +cur_path = Path(__file__).parents[1].resolve() +sys.path.append(str(cur_path)) + +from musicgen import MusicGen text = "folk ballad" -model = load("facebook/musicgen-medium") +model = MusicGen.from_pretrained("facebook/musicgen-medium") max_steps = 100 diff --git a/musicgen/benchmarks/bench_pt.py b/musicgen/benchmarks/bench_pt.py index 689c36eb..de01aa66 100644 --- a/musicgen/benchmarks/bench_pt.py +++ b/musicgen/benchmarks/bench_pt.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. 
+ import time import torch diff --git a/musicgen/generate.py b/musicgen/generate.py index b5c92c6c..5a6b7804 100644 --- a/musicgen/generate.py +++ b/musicgen/generate.py @@ -2,11 +2,13 @@ import argparse -from utils import load, save_audio +from utils import save_audio + +from musicgen import MusicGen def main(text: str, output_path: str, model_name: str, max_steps: int): - model = load(model_name) + model = MusicGen.from_pretrained(model_name) audio = model.generate(text, max_steps=max_steps) save_audio(output_path, audio, model.sampling_rate) diff --git a/musicgen/musicgen.py b/musicgen/musicgen.py index 892a54f9..4139117b 100644 --- a/musicgen/musicgen.py +++ b/musicgen/musicgen.py @@ -1,8 +1,10 @@ # Copyright © 2024 Apple Inc. +import json import sys from functools import partial from pathlib import Path +from types import SimpleNamespace from typing import Optional import mlx.core as mx @@ -13,14 +15,14 @@ cur_path = Path(__file__).parents[1].resolve() sys.path.append(str(cur_path / "encodec")) sys.path.append(str(cur_path / "t5")) -from encodec import load as load_encodec -from t5 import load_model as load_t5 +from encodec import EncodecModel +from t5 import T5 class TextConditioner(nn.Module): def __init__(self, t5_name, input_dim, output_dim): super().__init__() - self._t5, self.tokenizer = load_t5(t5_name) + self._t5, self.tokenizer = T5.from_pretrained(t5_name) self.output_proj = nn.Linear(input_dim, output_dim) def __call__(self, text): @@ -222,7 +224,9 @@ class MusicGen(nn.Module): ] encodec_name = config.audio_encoder._name_or_path.split("/")[-1] encodec_name = encodec_name.replace("_", "-") - self._audio_decoder, _ = load_encodec(f"mlx-community/{encodec_name}-float32") + self._audio_decoder, _ = EncodecModel.from_pretrained( + f"mlx-community/{encodec_name}-float32" + ) def __call__( self, @@ -304,3 +308,57 @@ class MusicGen(nn.Module): audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis] audio = self._audio_decoder.decode(audio_seq, 
audio_scales=[None]) return audio[0] + + @classmethod + def sanitize(cls, weights): + out_weights = {} + for k, arr in weights.items(): + if k.startswith("transformer."): + k = k[len("transformer.") :] + + if "cross_attention" in k: + k = k.replace("cross_attention", "cross_attn") + + if "condition_provider" in k: + k = k.replace( + "condition_provider.conditioners.description", "text_conditioner" + ) + + if "in_proj_weight" in k: + dim = arr.shape[0] // 3 + name = "in_proj_weight" + out_weights[k.replace(name, "q_proj.weight")] = arr[:dim] + out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2] + out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :] + continue + + out_weights[k] = arr + return out_weights + + @classmethod + def from_pretrained(cls, path_or_repo: str): + import torch + from huggingface_hub import snapshot_download + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "state_dict.bin"], + ) + ) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + config.text_encoder = SimpleNamespace(**config.text_encoder) + config.audio_encoder = SimpleNamespace(**config.audio_encoder) + config.decoder = SimpleNamespace(**config.decoder) + + weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"] + weights = {k: mx.array(v.numpy()) for k, v in weights.items()} + weights = cls.sanitize(weights) + + model = MusicGen(config) + model.load_weights(list(weights.items())) + return model diff --git a/musicgen/utils.py b/musicgen/utils.py index ed10fa2e..78e92571 100644 --- a/musicgen/utils.py +++ b/musicgen/utils.py @@ -1,14 +1,7 @@ # Copyright © 2024 Apple Inc. 
-import json -from pathlib import Path -from types import SimpleNamespace - import mlx.core as mx import numpy as np -from huggingface_hub import snapshot_download - -import musicgen def save_audio(file: str, audio: mx.array, sampling_rate: int): @@ -20,56 +13,3 @@ def save_audio(file: str, audio: mx.array, sampling_rate: int): audio = mx.clip(audio, -1, 1) audio = (audio * 32767).astype(mx.int16) write(file, sampling_rate, np.array(audio)) - - -def load(path_or_repo): - """ - Load the model and audio preprocessor. - """ - import torch - - path = Path(path_or_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_repo, - allow_patterns=["*.json", "state_dict.bin"], - ) - ) - - with open(path / "config.json", "r") as f: - config = SimpleNamespace(**json.load(f)) - config.text_encoder = SimpleNamespace(**config.text_encoder) - config.audio_encoder = SimpleNamespace(**config.audio_encoder) - config.decoder = SimpleNamespace(**config.decoder) - - weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"] - weights = {k: mx.array(v.numpy()) for k, v in weights.items()} - - decoder_weights = {} - for k, arr in weights.items(): - if k.startswith("transformer."): - k = k[len("transformer.") :] - - if "cross_attention" in k: - k = k.replace("cross_attention", "cross_attn") - - if "condition_provider" in k: - k = k.replace( - "condition_provider.conditioners.description", "text_conditioner" - ) - - if "in_proj_weight" in k: - dim = arr.shape[0] // 3 - name = "in_proj_weight" - decoder_weights[k.replace(name, "q_proj.weight")] = arr[:dim] - decoder_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2] - decoder_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :] - continue - - decoder_weights[k] = arr - - model = musicgen.MusicGen(config) - model.load_weights(list(decoder_weights.items())) - mx.eval(model) - return model diff --git a/t5/README.md b/t5/README.md index a0cc861b..e5165f8f 100644 --- a/t5/README.md 
+++ b/t5/README.md @@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.: This example also supports the FLAN-T5 models variants.[^2] -## Setup - -Download and convert the model: - -```sh -python convert.py --model <model> -``` - -This will make the `.npz` file which MLX can read. - -The `<model>` can be any of the following: - -| Model Name | Model Size | -| ---------- | ---------- -| t5-small | 60 million | -| t5-base | 220 million | -| t5-large | 770 million | -| t5-3b | 3 billion | -| t5-11b | 11 billion | - -The FLAN variants can be specified with `google/flan-t5-small`, -`google/flan-t5-base`, etc. See the [Hugging Face -page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a -complete list of models. - ## Generate Generate text with: @@ -48,6 +23,21 @@ To see a list of options run: python t5.py --help ``` +The `<model>` can be any of the following: + +| Model Name | Model Size | +| ---------- | ---------- +| t5-small | 60 million | +| t5-base | 220 million | +| t5-large | 770 million | +| t5-3b | 3 billion | +| t5-11b | 11 billion | + +The FLAN variants can be specified with `google/flan-t5-small`, +`google/flan-t5-base`, etc. See the [Hugging Face +page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a +complete list of models. + [^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683) or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5). [^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
diff --git a/t5/t5.py b/t5/t5.py index f9a582d7..651e8e1e 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -10,36 +10,36 @@ import mlx.nn as nn import numpy as np from transformers import AutoTokenizer -SHARED_REPLACEMENT_PATTERNS = [ - (".block.", ".layers."), - (".k.", ".key_proj."), - (".o.", ".out_proj."), - (".q.", ".query_proj."), - (".v.", ".value_proj."), - ("shared.", "wte."), - ("lm_head.", "lm_head.linear."), - (".layer.0.layer_norm.", ".ln1."), - (".layer.1.layer_norm.", ".ln2."), - (".layer.2.layer_norm.", ".ln3."), - (".final_layer_norm.", ".ln."), - ( - "layers.0.layer.0.SelfAttention.relative_attention_bias.", - "relative_attention_bias.embeddings.", - ), -] -ENCODER_REPLACEMENT_PATTERNS = [ - (".layer.0.SelfAttention.", ".attention."), - (".layer.1.DenseReluDense.", ".dense."), -] +class Tokenizer: + def __init__(self, config, model_name): + self._decoder_start_id = config.decoder_start_token_id + self._tokenizer = AutoTokenizer.from_pretrained( + model_name, + legacy=False, + model_max_length=getattr(config, "n_positions", 512), + ) -DECODER_REPLACEMENT_PATTERNS = [ - (".layer.0.SelfAttention.", ".self_attention."), - (".layer.1.EncDecAttention.", ".cross_attention."), - (".layer.2.DenseReluDense.", ".dense."), -] + @property + def eos_id(self) -> int: + return self._tokenizer.eos_token_id -IGNORED_KEYS = ["decoder.layers.0.cross_attention.relative_attention_bias.weight"] + @property + def decoder_start_id(self) -> int: + return self._decoder_start_id + + def encode(self, s: str) -> mx.array: + return mx.array( + self._tokenizer( + s, + return_tensors="np", + return_attention_mask=False, + )["input_ids"] + ) + + def decode(self, t: List[int], with_sep: bool = True) -> str: + tokens = self._tokenizer.convert_ids_to_tokens(t) + return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) def _relative_position_bucket( @@ -350,36 +350,83 @@ class T5(nn.Module): ): return self.decode(decoder_inputs, self.encode(inputs))[0] + @classmethod + def 
sanitize(cls, weights): + shared_replacement_patterns = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), + ] -class Tokenizer: - def __init__(self, config, model_name): - self._decoder_start_id = config.decoder_start_token_id - self._tokenizer = AutoTokenizer.from_pretrained( - model_name, - legacy=False, - model_max_length=getattr(config, "n_positions", 512), - ) + encoder_replacement_patterns = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), + ] - @property - def eos_id(self) -> int: - return self._tokenizer.eos_token_id + decoder_replacement_patterns = [ + (".layer.0.SelfAttention.", ".self_attention."), + (".layer.1.EncDecAttention.", ".cross_attention."), + (".layer.2.DenseReluDense.", ".dense."), + ] - @property - def decoder_start_id(self) -> int: - return self._decoder_start_id + ignored_keys = [ + "decoder.layers.0.cross_attention.relative_attention_bias.weight" + ] - def encode(self, s: str) -> mx.array: - return mx.array( - self._tokenizer( - s, - return_tensors="np", - return_attention_mask=False, - )["input_ids"] - ) + def replace_key(key: str) -> str: + for old, new in shared_replacement_patterns: + key = key.replace(old, new) + if key.startswith("encoder."): + for old, new in encoder_replacement_patterns: + key = key.replace(old, new) + elif key.startswith("decoder."): + for old, new in decoder_replacement_patterns: + key = key.replace(old, new) + return key - def decode(self, t: List[int], with_sep: bool = True) -> str: - tokens = self._tokenizer.convert_ids_to_tokens(t) - return "".join(t.replace("▁", " " if 
with_sep else "") for t in tokens) + weights = {replace_key(k): v for k, v in weights.items()} + for key in ignored_keys: + if key in weights: + del weights[key] + return weights + + @classmethod + def from_pretrained( + cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16 + ) -> tuple["T5", Tokenizer]: + from huggingface_hub import snapshot_download + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "*.safetensors", "*.model"], + ) + ) + print(path) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + + model = T5(config) + weights = mx.load(str(path / "model.safetensors")) + weights = cls.sanitize(weights) + weights = {k: v.astype(dtype) for k, v in weights.items()} + model.load_weights(list(weights.items())) + return model, Tokenizer(config, "t5-base") def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0): @@ -400,47 +447,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] yield y.squeeze() -def replace_key(key: str) -> str: - for old, new in SHARED_REPLACEMENT_PATTERNS: - key = key.replace(old, new) - if key.startswith("encoder."): - for old, new in ENCODER_REPLACEMENT_PATTERNS: - key = key.replace(old, new) - elif key.startswith("decoder."): - for old, new in DECODER_REPLACEMENT_PATTERNS: - key = key.replace(old, new) - return key - - -def load_model(path_or_repo: str, dtype: str = "bfloat16"): - from huggingface_hub import snapshot_download - - path = Path(path_or_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_repo, - allow_patterns=["*.json", "*.safetensors", "*.model"], - ) - ) - print(path) - - with open(path / "config.json", "r") as f: - config = SimpleNamespace(**json.load(f)) - - dtype = getattr(mx, dtype) - - model = T5(config) - weights = mx.load(str(path / "model.safetensors")) - weights = {replace_key(k): 
v.astype(dtype) for k, v in weights.items()} - for key in IGNORED_KEYS: - del weights[key] - model.load_weights(list(weights.items())) - - mx.eval(model.parameters()) - return model, Tokenizer(config, "t5-base") - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="T5 Inference script") parser.add_argument( @@ -486,7 +492,8 @@ if __name__ == "__main__": mx.random.seed(args.seed) - model, tokenizer = load_model(args.model, args.dtype) + dtype = getattr(mx, args.dtype) + model, tokenizer = T5.from_pretrained(args.model, dtype) if args.encode_only: print("[INFO] Encoding with T5...", flush=True)