mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-21 02:51:13 +08:00
210 lines
5.9 KiB
Python
210 lines
5.9 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import json
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import mlx.core as mx
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from .autoencoder import AutoEncoder, AutoEncoderParams
|
|
from .clip import CLIPTextModel, CLIPTextModelConfig
|
|
from .model import Flux, FluxParams
|
|
from .t5 import T5Config, T5Encoder
|
|
from .tokenizers import CLIPTokenizer, T5Tokenizer
|
|
|
|
|
|
@dataclass
class ModelSpec:
    """Everything needed to locate and build one FLUX model variant.

    A spec pairs the architecture hyperparameters (for the flow transformer
    and the autoencoder) with the places their weights can be found: an
    optional local checkpoint path and an optional Hugging Face repo to
    download from when no local path is set.
    """

    # Hyperparameters for the flow (transformer) model.
    params: FluxParams
    # Hyperparameters for the autoencoder (VAE).
    ae_params: AutoEncoderParams
    # Local path to the flow checkpoint; None means download from the Hub.
    ckpt_path: Optional[str]
    # Local path to the autoencoder checkpoint; None means download from the Hub.
    ae_path: Optional[str]
    # Hugging Face repo id hosting the weights (None disables downloading).
    repo_id: Optional[str]
    # Filename of the flow safetensors file inside the repo.
    repo_flow: Optional[str]
    # Filename of the autoencoder safetensors file inside the repo.
    repo_ae: Optional[str]
|
|
|
|
|
|
def _flux_spec(variant: str, env_var: str, guidance_embed: bool) -> ModelSpec:
    """Build the ModelSpec for one FLUX release.

    The "dev" and "schnell" releases share every hyperparameter except
    ``guidance_embed`` and their repo/checkpoint names, so the common
    values live here instead of being duplicated per entry.

    Args:
        variant: Release suffix used in the repo id and flow filename
            (``"dev"`` or ``"schnell"``).
        env_var: Environment variable that may point at a local flow
            checkpoint, bypassing the Hub download.
        guidance_embed: Whether the flow model has a guidance embedding
            (True for dev, False for schnell).
    """
    return ModelSpec(
        repo_id=f"black-forest-labs/FLUX.1-{variant}",
        repo_flow=f"flux1-{variant}.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv(env_var),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=guidance_embed,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    )


# Registry of supported model variants, keyed by the names accepted by the
# load_* functions below.
configs = {
    "flux-dev": _flux_spec("dev", "FLUX_DEV", guidance_embed=True),
    "flux-schnell": _flux_spec("schnell", "FLUX_SCHNELL", guidance_embed=False),
}
|
|
|
|
|
|
def load_flow_model(name: str, hf_download: bool = True):
    """Build the FLUX flow (transformer) model for config *name*.

    A local checkpoint path from the config takes precedence; otherwise the
    weights are fetched from the Hugging Face Hub when ``hf_download`` is
    True and the config names a repo. With neither available, the model is
    returned uninitialized.
    """
    spec = configs[name]

    # Resolve the safetensors file: local path first, then the Hub.
    ckpt_path = spec.ckpt_path
    needs_download = (
        ckpt_path is None
        and hf_download
        and spec.repo_id is not None
        and spec.repo_flow is not None
    )
    if needs_download:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_flow)

    # Instantiate the model, then load weights if we found a checkpoint.
    model = Flux(spec.params)
    if ckpt_path is not None:
        raw_weights = mx.load(ckpt_path)
        model.load_weights(list(model.sanitize(raw_weights).items()))

    return model
|
|
|
|
|
|
def load_ae(name: str, hf_download: bool = True):
    """Build the autoencoder (VAE) for config *name*.

    A local checkpoint path from the config takes precedence; otherwise the
    weights are fetched from the Hugging Face Hub when ``hf_download`` is
    True and the config names a repo. With neither available, the
    autoencoder is returned uninitialized.
    """
    spec = configs[name]

    # Resolve the safetensors file: local path first, then the Hub.
    ckpt_path = spec.ae_path
    needs_download = (
        ckpt_path is None
        and hf_download
        and spec.repo_id is not None
        and spec.repo_ae is not None
    )
    if needs_download:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_ae)

    # Instantiate the autoencoder, then load weights if we found a checkpoint.
    ae = AutoEncoder(spec.ae_params)
    if ckpt_path is not None:
        raw_weights = mx.load(ckpt_path)
        ae.load_weights(list(ae.sanitize(raw_weights).items()))

    return ae
|
|
|
|
|
|
def load_clip(name: str):
    """Download and load the CLIP text encoder for config *name*.

    Both the config and the weights come from the ``text_encoder/``
    directory of the model's Hugging Face repo.
    """
    spec = configs[name]

    # Fetch and parse the encoder configuration.
    config_file = hf_hub_download(spec.repo_id, "text_encoder/config.json")
    with open(config_file) as fh:
        clip_config = CLIPTextModelConfig.from_dict(json.load(fh))

    # Build the encoder and load its weights.
    clip = CLIPTextModel(clip_config)
    weight_file = hf_hub_download(spec.repo_id, "text_encoder/model.safetensors")
    raw_weights = mx.load(weight_file)
    clip.load_weights(list(clip.sanitize(raw_weights).items()))

    return clip
|
|
|
|
|
|
def load_t5(name: str):
    """Download and load the T5 text encoder for config *name*.

    The config and the (sharded) weights come from the ``text_encoder_2/``
    directory of the model's Hugging Face repo; the shard list is read
    from the safetensors index file.
    """
    spec = configs[name]

    # Fetch and parse the encoder configuration.
    config_file = hf_hub_download(spec.repo_id, "text_encoder_2/config.json")
    with open(config_file) as fh:
        t5_config = T5Config.from_dict(json.load(fh))

    t5 = T5Encoder(t5_config)

    # The index maps each weight name to the shard file containing it;
    # we only need the distinct shard filenames.
    index_file = hf_hub_download(
        spec.repo_id, "text_encoder_2/model.safetensors.index.json"
    )
    with open(index_file) as fh:
        shard_names = set(json.load(fh)["weight_map"].values())

    # Download every shard and merge them into one weight dict.
    weights = {}
    for shard in shard_names:
        local_path = hf_hub_download(spec.repo_id, f"text_encoder_2/{shard}")
        weights.update(mx.load(local_path))

    t5.load_weights(list(t5.sanitize(weights).items()))
    return t5
|
|
|
|
|
|
def load_clip_tokenizer(name: str):
    """Download the CLIP BPE vocab and merges and build the tokenizer."""
    repo = configs[name].repo_id

    # Token -> id vocabulary.
    vocab_path = hf_hub_download(repo, "tokenizer/vocab.json")
    with open(vocab_path, encoding="utf-8") as fh:
        vocab = json.load(fh)

    # BPE merge rules: skip the header line and cap at the standard CLIP
    # merge count; rank is the line's position in the file.
    merges_path = hf_hub_download(repo, "tokenizer/merges.txt")
    with open(merges_path, encoding="utf-8") as fh:
        merge_lines = fh.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
    bpe_ranks = {
        tuple(line.split()): rank for rank, line in enumerate(merge_lines)
    }

    return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
|
|
|
|
|
|
def load_t5_tokenizer(name: str, pad: bool = True):
    """Download the T5 SentencePiece model and build the tokenizer.

    NOTE(review): the ``pad`` argument is accepted but never used in this
    body — kept for interface compatibility; confirm against callers.
    """
    spiece_path = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
    # Schnell uses the shorter 256-token context; dev uses 512.
    max_length = 256 if "schnell" in name else 512
    return T5Tokenizer(spiece_path, max_length)
|