mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-09 10:51:14 +08:00

- bert/model.py:10: tree_unflatten - bert/model.py:2: dataclass - bert/model.py:8: numpy - cifar/resnet.py:6: Any - clip/model.py:15: tree_flatten - clip/model.py:9: Union - gcn/main.py:8: download_cora - gcn/main.py:9: cross_entropy - llms/gguf_llm/models.py:12: tree_flatten, tree_unflatten - llms/gguf_llm/models.py:9: numpy - llms/mixtral/mixtral.py:12: tree_map - llms/mlx_lm/models/dbrx.py:2: Dict, Union - llms/mlx_lm/tuner/trainer.py:5: partial - llms/speculative_decoding/decoder.py:1: dataclass, field - llms/speculative_decoding/decoder.py:2: Optional - llms/speculative_decoding/decoder.py:5: mlx.nn - llms/speculative_decoding/decoder.py:6: numpy - llms/speculative_decoding/main.py:2: glob - llms/speculative_decoding/main.py:3: json - llms/speculative_decoding/main.py:5: Path - llms/speculative_decoding/main.py:8: mlx.nn - llms/speculative_decoding/model.py:6: tree_unflatten - llms/speculative_decoding/model.py:7: AutoTokenizer - llms/tests/test_lora.py:13: yaml_loader - lora/lora.py:14: tree_unflatten - lora/models.py:11: numpy - lora/models.py:3: glob - speechcommands/kwt.py:1: Any - speechcommands/main.py:7: mlx.data - stable_diffusion/stable_diffusion/model_io.py:4: partial - whisper/benchmark.py:5: sys - whisper/test.py:5: subprocess - whisper/whisper/audio.py:6: Optional - whisper/whisper/decoding.py:8: mlx.nn
331 lines
11 KiB
Python
331 lines
11 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
import json
|
|
from typing import Optional
|
|
|
|
import mlx.core as mx
|
|
from huggingface_hub import hf_hub_download
|
|
from mlx.utils import tree_unflatten
|
|
|
|
from .clip import CLIPTextModel
|
|
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
|
|
from .tokenizer import Tokenizer
|
|
from .unet import UNetModel
|
|
from .vae import Autoencoder
|
|
|
|
# Hugging Face Hub repository used when a loader is called without a key.
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"

# Supported repositories. Each entry maps a logical component name (looked up
# by the load_* functions in this module) to the path of the corresponding
# file inside the repository.
_MODELS = {
    # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
    "stabilityai/sdxl-turbo": {
        "unet_config": "unet/config.json",
        "unet": "unet/diffusion_pytorch_model.safetensors",
        "text_encoder_config": "text_encoder/config.json",
        "text_encoder": "text_encoder/model.safetensors",
        # This repository lists a second text encoder / tokenizer pair.
        "text_encoder_2_config": "text_encoder_2/config.json",
        "text_encoder_2": "text_encoder_2/model.safetensors",
        "vae_config": "vae/config.json",
        "vae": "vae/diffusion_pytorch_model.safetensors",
        "diffusion_config": "scheduler/scheduler_config.json",
        "tokenizer_vocab": "tokenizer/vocab.json",
        "tokenizer_merges": "tokenizer/merges.txt",
        "tokenizer_2_vocab": "tokenizer_2/vocab.json",
        "tokenizer_2_merges": "tokenizer_2/merges.txt",
    },
    # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
    "stabilityai/stable-diffusion-2-1-base": {
        "unet_config": "unet/config.json",
        "unet": "unet/diffusion_pytorch_model.safetensors",
        "text_encoder_config": "text_encoder/config.json",
        "text_encoder": "text_encoder/model.safetensors",
        "vae_config": "vae/config.json",
        "vae": "vae/diffusion_pytorch_model.safetensors",
        "diffusion_config": "scheduler/scheduler_config.json",
        "tokenizer_vocab": "tokenizer/vocab.json",
        "tokenizer_merges": "tokenizer/merges.txt",
    },
}
|
|
|
|
|
|
def map_unet_weights(key, value):
    """Translate one UNet checkpoint entry to this package's naming.

    Args:
        key: Parameter name as stored in the checkpoint.
        value: The array stored under ``key``.

    Returns:
        A list of ``(key, value)`` pairs. It holds a single pair except for
        ``ff.net.0`` entries, whose array is split in two and returned as a
        ``linear1`` and a ``linear2`` pair.
    """
    # Substring renames, applied in order. ``str.replace`` leaves the key
    # untouched when the pattern is absent, so no membership test is needed.
    renames = (
        # Up/downsampling
        ("downsamplers.0.conv", "downsample"),
        ("upsamplers.0.conv", "upsample"),
        # Mid block
        ("mid_block.resnets.0", "mid_blocks.0"),
        ("mid_block.attentions.0", "mid_blocks.1"),
        ("mid_block.resnets.1", "mid_blocks.2"),
        # Attention layers
        ("to_k", "key_proj"),
        ("to_out.0", "out_proj"),
        ("to_q", "query_proj"),
        ("to_v", "value_proj"),
        # Last linear of the transformer ffn
        ("ff.net.2", "linear3"),
    )
    for old, new in renames:
        key = key.replace(old, new)

    # The first ffn projection is stored as one array; split it in two and
    # hand the halves to linear1 and linear2.
    if "ff.net.0" in key:
        first_key = key.replace("ff.net.0.proj", "linear1")
        second_key = key.replace("ff.net.0.proj", "linear2")
        first_half, second_half = mx.split(value, 2)
        return [(first_key, first_half), (second_key, second_half)]

    if "conv_shortcut.weight" in key:
        value = value.squeeze()

    # Transform the weights from 1x1 convs to linear
    is_projection = "proj_in" in key or "proj_out" in key
    if is_projection and len(value.shape) == 4:
        value = value.squeeze()

    # Whatever is still 4-D is moved from axis order (0, 1, 2, 3) to
    # (0, 2, 3, 1).
    if len(value.shape) == 4:
        value = value.transpose(0, 2, 3, 1)
        # NOTE(review): the flatten/unflatten round trip presumably forces a
        # contiguous copy after the transpose -- confirm before removing.
        value = value.reshape(-1).reshape(value.shape)

    return [(key, value)]
|
|
|
|
|
|
def map_clip_text_encoder_weights(key, value):
    """Translate one CLIP text encoder checkpoint entry to this package's naming.

    Args:
        key: Parameter name as stored in the checkpoint.
        value: The array stored under ``key``; it is returned unchanged.

    Returns:
        A one-element list holding the ``(renamed_key, value)`` pair.
    """
    # Remove prefixes. They are tried one after the other, so a key may lose
    # several of them (e.g. "text_model." followed by "encoder.").
    for prefix in ("text_model.", "embeddings.", "encoder."):
        if key.startswith(prefix):
            key = key[len(prefix):]

    # Substring renames, applied in order.
    renames = (
        # Attention layers
        ("self_attn.", "attention."),
        ("q_proj.", "query_proj."),
        ("k_proj.", "key_proj."),
        ("v_proj.", "value_proj."),
        # Ffn layers
        ("mlp.fc1", "linear1"),
        ("mlp.fc2", "linear2"),
    )
    for old, new in renames:
        key = key.replace(old, new)

    return [(key, value)]
|
|
|
|
|
|
def map_vae_weights(key, value):
    """Translate one autoencoder checkpoint entry to this package's naming.

    Args:
        key: Parameter name as stored in the checkpoint.
        value: The array stored under ``key``.

    Returns:
        A one-element list holding the ``(renamed_key, converted_value)`` pair.
    """
    # Substring renames, applied in order. ``str.replace`` is a no-op when
    # the pattern is absent.
    renames = (
        # Up/downsampling
        ("downsamplers.0.conv", "downsample"),
        ("upsamplers.0.conv", "upsample"),
        # Attention layers
        ("to_k", "key_proj"),
        ("to_out.0", "out_proj"),
        ("to_q", "query_proj"),
        ("to_v", "value_proj"),
        # Mid block
        ("mid_block.resnets.0", "mid_blocks.0"),
        ("mid_block.attentions.0", "mid_blocks.1"),
        ("mid_block.resnets.1", "mid_blocks.2"),
    )
    for old, new in renames:
        key = key.replace(old, new)

    # Map the quant/post_quant layers. The substring test also matches
    # "post_quant_conv" keys.
    if "quant_conv" in key:
        key = key.replace("quant_conv", "quant_proj")
        value = value.squeeze()

    # Map the conv_shortcut to linear
    if "conv_shortcut.weight" in key:
        value = value.squeeze()

    # Whatever is still 4-D is moved from axis order (0, 1, 2, 3) to
    # (0, 2, 3, 1).
    if len(value.shape) == 4:
        value = value.transpose(0, 2, 3, 1)
        # NOTE(review): the flatten/unflatten round trip presumably forces a
        # contiguous copy after the transpose -- confirm before removing.
        value = value.reshape(-1).reshape(value.shape)

    return [(key, value)]
|
|
|
|
|
|
def _flatten(params):
    """Concatenate an iterable of ``(key, value)`` sequences into one list.

    Every inner item is unpacked into exactly two elements and stored as a
    tuple, so an item that is not a pair raises on unpacking.
    """
    flat = []
    for pairs in params:
        for name, array in pairs:
            flat.append((name, array))
    return flat
|
|
|
|
|
|
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
    """Load a weight file, rename its entries and write them into ``model``.

    Args:
        mapper: Callable ``(key, value) -> list of (key, value)`` applied to
            every entry of the file.
        model: Object updated in place through ``model.update``.
        weight_file: Path handed to ``mx.load``.
        float16: Cast every array to ``mx.float16`` when true, otherwise to
            ``mx.float32``.
    """
    if float16:
        dtype = mx.float16
    else:
        dtype = mx.float32

    raw_weights = mx.load(weight_file)
    # Each mapper call yields a list of pairs; collect them per entry first
    # and merge them into a single flat list afterwards.
    mapped = [mapper(name, array.astype(dtype)) for name, array in raw_weights.items()]
    flat_weights = _flatten(mapped)
    model.update(tree_unflatten(flat_weights))
|
|
|
|
|
|
def _check_key(key: str, part: str):
    """Validate that ``key`` names an entry of ``_MODELS``.

    Args:
        key: Model identifier to validate.
        part: Name of the calling loader; it prefixes the error message.

    Raises:
        ValueError: If ``key`` is not in ``_MODELS``.
    """
    if key in _MODELS:
        return

    raise ValueError(
        f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
    )
|
|
|
|
|
|
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Load the stable diffusion UNet from Hugging Face Hub.

    Args:
        key: Model identifier; must be an entry of ``_MODELS``.
        float16: Forwarded to ``_load_safetensor_weights``; selects
            ``mx.float16`` instead of ``mx.float32`` for the loaded weights.

    Returns:
        A ``UNetModel`` built from the downloaded config, with the downloaded
        weights loaded into it.

    Raises:
        ValueError: If ``key`` is not an entry of ``_MODELS``.
    """
    _check_key(key, "load_unet")

    # Download the config and create the model
    unet_config = _MODELS[key]["unet_config"]
    # JSON is UTF-8; be explicit (as load_tokenizer is) so the platform's
    # default encoding cannot affect parsing.
    with open(hf_hub_download(key, unet_config), encoding="utf-8") as f:
        config = json.load(f)

    n_blocks = len(config["block_out_channels"])

    # One entry per block. A missing key falls back to one transformer layer
    # per block, and a scalar is repeated for every block, mirroring the
    # handling of "attention_head_dim" below.
    transformer_layers_per_block = config.get("transformer_layers_per_block", 1)
    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = (transformer_layers_per_block,) * n_blocks

    # A scalar head count applies to every block.
    attention_head_dim = config["attention_head_dim"]
    if isinstance(attention_head_dim, int):
        attention_head_dim = [attention_head_dim] * n_blocks

    model = UNetModel(
        UNetConfig(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            block_out_channels=config["block_out_channels"],
            layers_per_block=[config["layers_per_block"]] * n_blocks,
            transformer_layers_per_block=transformer_layers_per_block,
            num_attention_heads=attention_head_dim,
            cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
            norm_num_groups=config["norm_num_groups"],
            down_block_types=config["down_block_types"],
            # The up block types are passed in reversed order.
            up_block_types=config["up_block_types"][::-1],
            addition_embed_type=config.get("addition_embed_type", None),
            addition_time_embed_dim=config.get("addition_time_embed_dim", None),
            projection_class_embeddings_input_dim=config.get(
                "projection_class_embeddings_input_dim", None
            ),
        )
    )

    # Download the weights and map them into the model
    unet_weights = _MODELS[key]["unet"]
    weight_file = hf_hub_download(key, unet_weights)
    _load_safetensor_weights(map_unet_weights, model, weight_file, float16)

    return model
|
|
|
|
|
|
def load_text_encoder(
    key: str = _DEFAULT_MODEL,
    float16: bool = False,
    model_key: str = "text_encoder",
    config_key: Optional[str] = None,
):
    """Load the stable diffusion text encoder from Hugging Face Hub.

    Args:
        key: Model identifier; must be an entry of ``_MODELS``.
        float16: Forwarded to ``_load_safetensor_weights``; selects
            ``mx.float16`` instead of ``mx.float32`` for the loaded weights.
        model_key: Name of the ``_MODELS[key]`` entry holding the weight file.
        config_key: Name of the ``_MODELS[key]`` entry holding the config
            file. Defaults to ``model_key + "_config"``.

    Returns:
        A ``CLIPTextModel`` built from the downloaded config, with the
        downloaded weights loaded into it.

    Raises:
        ValueError: If ``key`` is not an entry of ``_MODELS``.
        KeyError: If ``model_key`` or ``config_key`` is not listed for ``key``.
    """
    _check_key(key, "load_text_encoder")

    config_key = config_key or (model_key + "_config")

    # Download the config and create the model
    text_encoder_config = _MODELS[key][config_key]
    # JSON is UTF-8; be explicit (as load_tokenizer is) so the platform's
    # default encoding cannot affect parsing.
    with open(hf_hub_download(key, text_encoder_config), encoding="utf-8") as f:
        config = json.load(f)

    # The projection size is only read from the config when the first listed
    # architecture name contains "WithProjection".
    with_projection = "WithProjection" in config["architectures"][0]

    model = CLIPTextModel(
        CLIPTextModelConfig(
            num_layers=config["num_hidden_layers"],
            model_dims=config["hidden_size"],
            num_heads=config["num_attention_heads"],
            max_length=config["max_position_embeddings"],
            vocab_size=config["vocab_size"],
            projection_dim=config["projection_dim"] if with_projection else None,
            hidden_act=config.get("hidden_act", "quick_gelu"),
        )
    )

    # Download the weights and map them into the model
    text_encoder_weights = _MODELS[key][model_key]
    weight_file = hf_hub_download(key, text_encoder_weights)
    _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)

    return model
|
|
|
|
|
|
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Load the stable diffusion autoencoder from Hugging Face Hub.

    Args:
        key: Model identifier; must be an entry of ``_MODELS``.
        float16: Forwarded to ``_load_safetensor_weights``; selects
            ``mx.float16`` instead of ``mx.float32`` for the loaded weights.

    Returns:
        An ``Autoencoder`` built from the downloaded config, with the
        downloaded weights loaded into it.

    Raises:
        ValueError: If ``key`` is not an entry of ``_MODELS``.
    """
    _check_key(key, "load_autoencoder")

    # Download the config and create the model
    vae_config = _MODELS[key]["vae_config"]
    # JSON is UTF-8; be explicit (as load_tokenizer is) so the platform's
    # default encoding cannot affect parsing.
    with open(hf_hub_download(key, vae_config), encoding="utf-8") as f:
        config = json.load(f)

    model = Autoencoder(
        AutoencoderConfig(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            # The output side is configured with twice the latent channels.
            latent_channels_out=2 * config["latent_channels"],
            latent_channels_in=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            layers_per_block=config["layers_per_block"],
            norm_num_groups=config["norm_num_groups"],
            # Fall back to 0.18215 when the config has no scaling factor.
            scaling_factor=config.get("scaling_factor", 0.18215),
        )
    )

    # Download the weights and map them into the model
    vae_weights = _MODELS[key]["vae"]
    weight_file = hf_hub_download(key, vae_weights)
    _load_safetensor_weights(map_vae_weights, model, weight_file, float16)

    return model
|
|
|
|
|
|
def load_diffusion_config(key: str = _DEFAULT_MODEL):
    """Load the stable diffusion config from Hugging Face Hub.

    Args:
        key: Model identifier; must be an entry of ``_MODELS``.

    Returns:
        A ``DiffusionConfig`` filled from the downloaded scheduler config.

    Raises:
        ValueError: If ``key`` is not an entry of ``_MODELS``.
    """
    _check_key(key, "load_diffusion_config")

    diffusion_config = _MODELS[key]["diffusion_config"]
    # JSON is UTF-8; be explicit (as load_tokenizer is) so the platform's
    # default encoding cannot affect parsing.
    with open(hf_hub_download(key, diffusion_config), encoding="utf-8") as f:
        config = json.load(f)

    return DiffusionConfig(
        beta_start=config["beta_start"],
        beta_end=config["beta_end"],
        beta_schedule=config["beta_schedule"],
        # The config file calls this field "num_train_timesteps".
        num_train_steps=config["num_train_timesteps"],
    )
|
|
|
|
|
|
def load_tokenizer(
    key: str = _DEFAULT_MODEL,
    vocab_key: str = "tokenizer_vocab",
    merges_key: str = "tokenizer_merges",
):
    """Load a tokenizer's vocabulary and BPE merges from Hugging Face Hub.

    Args:
        key: Model identifier; must be an entry of ``_MODELS``.
        vocab_key: Name of the ``_MODELS[key]`` entry holding the vocab file.
        merges_key: Name of the ``_MODELS[key]`` entry holding the merges file.

    Returns:
        A ``Tokenizer`` built from the merge ranks and the vocabulary.

    Raises:
        ValueError: If ``key`` is not an entry of ``_MODELS``.
    """
    _check_key(key, "load_tokenizer")

    files = _MODELS[key]

    vocab_path = hf_hub_download(key, files[vocab_key])
    with open(vocab_path, encoding="utf-8") as f:
        vocab = json.load(f)

    merges_path = hf_hub_download(key, files[merges_key])
    with open(merges_path, encoding="utf-8") as f:
        merge_lines = f.read().strip().split("\n")

    # Skip the first line (presumably a header -- confirm against the file)
    # and keep at most 49152 - 256 - 2 merge rules.
    merge_lines = merge_lines[1 : 49152 - 256 - 2 + 1]
    merges = [tuple(line.split()) for line in merge_lines]

    # Rank of a merge is its position in the file; a repeated merge keeps the
    # last position seen.
    bpe_ranks = {pair: rank for rank, pair in enumerate(merges)}

    return Tokenizer(bpe_ranks, vocab)
|