mlx-examples/flux/flux/utils.py
# Copyright © 2024 Apple Inc.

import json
import os
from dataclasses import dataclass
from typing import Optional

import mlx.core as mx
from huggingface_hub import hf_hub_download

from .autoencoder import AutoEncoder, AutoEncoderParams
from .clip import CLIPTextModel, CLIPTextModelConfig
from .model import Flux, FluxParams
from .t5 import T5Config, T5Encoder
from .tokenizers import CLIPTokenizer, T5Tokenizer


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: Optional[str]
    ae_path: Optional[str]
    repo_id: Optional[str]
    repo_flow: Optional[str]
    repo_ae: Optional[str]


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
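
# Note: ckpt_path/ae_path above are resolved from the FLUX_DEV, FLUX_SCHNELL,
# and AE environment variables when this module is imported, so local
# checkpoints are only picked up if the variables are set beforehand. A minimal
# sketch (the paths below are hypothetical placeholders, and the import path
# assumes the package layout shown at the top of this file):
#
#     import os
#     os.environ["FLUX_DEV"] = "/path/to/flux1-dev.safetensors"
#     os.environ["AE"] = "/path/to/ae.safetensors"
#     from flux.utils import load_flow_model, load_ae  # import after setting the vars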


def load_flow_model(name: str, hf_download: bool = True):
    # Get the safetensors file to load
    ckpt_path = configs[name].ckpt_path

    # Download if needed
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    # Make the model
    model = Flux(configs[name].params)

    # Load the checkpoint if needed
    if ckpt_path is not None:
        weights = mx.load(ckpt_path)
        weights = model.sanitize(weights)
        model.load_weights(list(weights.items()))

    return model


def load_ae(name: str, hf_download: bool = True):
    # Get the safetensors file to load
    ckpt_path = configs[name].ae_path

    # Download if needed
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    # Make the autoencoder
    ae = AutoEncoder(configs[name].ae_params)

    # Load the checkpoint if needed
    if ckpt_path is not None:
        weights = mx.load(ckpt_path)
        weights = ae.sanitize(weights)
        ae.load_weights(list(weights.items()))

    return ae


def load_clip(name: str):
    # Load the config
    config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
    with open(config_path) as f:
        config = CLIPTextModelConfig.from_dict(json.load(f))

    # Make the clip text encoder
    clip = CLIPTextModel(config)

    # Load the weights
    ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
    weights = mx.load(ckpt_path)
    weights = clip.sanitize(weights)
    clip.load_weights(list(weights.items()))

    return clip


def load_t5(name: str):
    # Load the config
    config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
    with open(config_path) as f:
        config = T5Config.from_dict(json.load(f))

    # Make the T5 model
    t5 = T5Encoder(config)

    # The T5 weights are sharded; the index JSON maps each weight name to the
    # shard file that contains it. Collect the unique shard files, then load
    # and merge them all.
    model_index = hf_hub_download(
        configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
    )
    weight_files = set()
    with open(model_index) as f:
        for w in json.load(f)["weight_map"].values():
            weight_files.add(w)
    weights = {}
    for w in weight_files:
        w = f"text_encoder_2/{w}"
        w = hf_hub_download(configs[name].repo_id, w)
        weights.update(mx.load(w))
    weights = t5.sanitize(weights)
    t5.load_weights(list(weights.items()))

    return t5


def load_clip_tokenizer(name: str):
    vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
    with open(vocab_file, encoding="utf-8") as f:
        vocab = json.load(f)

    merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
    with open(merges_file, encoding="utf-8") as f:
        # Skip the header line and keep the merge rules backing the 49152-token
        # vocabulary (minus the 256 byte tokens and the 2 special tokens).
        bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
    bpe_merges = [tuple(m.split()) for m in bpe_merges]
    bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))

    return CLIPTokenizer(bpe_ranks, vocab, max_length=77)


def load_t5_tokenizer(name: str, pad: bool = True):
    model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
    # schnell uses a 256-token T5 sequence length; dev uses 512.
    return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
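

# A minimal usage sketch (illustrative, not part of the original module):
# loading every component for one of the configs above. "flux-schnell" is just
# an example key; each call downloads from the Hugging Face Hub unless a local
# path was configured via the environment variables noted earlier.
#
#     model = load_flow_model("flux-schnell")
#     ae = load_ae("flux-schnell")
#     clip = load_clip("flux-schnell")
#     clip_tokenizer = load_clip_tokenizer("flux-schnell")
#     t5 = load_t5("flux-schnell")
#     t5_tokenizer = load_t5_tokenizer("flux-schnell")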