Partial implementation of SDXL, incl. CLIP with projection, but doesn't produce good output yet.
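SDXL pipelines carry two text encoders: the second one is declared as `CLIPTextModelWithProjection` in its Hugging Face config and applies a linear projection on top of the transformer output. The diff below adds that variant to the CLIP module and threads a `name` argument through `load_text_encoder` and `load_tokenizer`, so the `text_encoder_2` / `tokenizer_2` files can be fetched from the same model repo.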
@@ -68,3 +68,13 @@ class CLIPTextModel(nn.Module):
 
         # Apply the final layernorm and return
         return self.final_layer_norm(x)
+
+
+class CLIPTextModelWithProjection(CLIPTextModel):
+    def __init__(self, config: CLIPTextModelConfig):
+        super().__init__(config)
+        self.projection = nn.Linear(config.model_dims, config.projection_dims)
+
+    def __call__(self, x):
+        x = super().__call__(x)
+        return self.projection(x)
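A quick shape check for the new class; a minimal sketch, assuming the mlx-examples `stable_diffusion` package layout (the import paths and toy dimensions are illustrative, not from the diff):

```python
import mlx.core as mx

from stable_diffusion.clip import CLIPTextModelWithProjection
from stable_diffusion.config import CLIPTextModelConfig

# Toy dimensions chosen only to exercise the shapes.
config = CLIPTextModelConfig(
    num_layers=2,
    model_dims=64,
    num_heads=4,
    max_length=77,
    vocab_size=49408,
    projection_dims=32,
)
model = CLIPTextModelWithProjection(config)

tokens = mx.zeros((1, 77), dtype=mx.int32)  # dummy token ids
out = model(tokens)
print(out.shape)  # (1, 77, 32): projection maps model_dims to projection_dims
```

Note that, unlike the Hugging Face class of the same name, this projection is applied to every hidden state rather than to a pooled output, since `super().__call__(x)` returns the full sequence.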
@@ -23,6 +23,7 @@ class CLIPTextModelConfig:
     num_heads: int = 16
     max_length: int = 77
     vocab_size: int = 49408
+    projection_dims: int = 512
 
 
 @dataclass
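For reference, `load_text_encoder` (further down) fills this dataclass from the encoder's config.json on the Hub. For SDXL's second text encoder, an OpenCLIP ViT-bigG model, the values below are quoted from memory of the published `text_encoder_2/config.json` and worth double-checking against the actual file:

```python
text_encoder_2_config = CLIPTextModelConfig(
    num_layers=32,         # num_hidden_layers
    model_dims=1280,       # hidden_size
    num_heads=20,          # num_attention_heads
    max_length=77,         # max_position_embeddings
    vocab_size=49408,
    projection_dims=1280,  # projection_dim
)
```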
@@ -12,7 +12,7 @@ from safetensors import safe_open as safetensor_open
 import mlx.core as mx
 from mlx.utils import tree_unflatten
 
-from .clip import CLIPTextModel
+from .clip import CLIPTextModel, CLIPTextModelWithProjection
 from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
 from .tokenizer import Tokenizer
 from .unet import UNetModel
@@ -31,7 +31,39 @@ _MODELS = {
         "diffusion_config": "scheduler/scheduler_config.json",
         "tokenizer_vocab": "tokenizer/vocab.json",
         "tokenizer_merges": "tokenizer/merges.txt",
-    }
+    },
+    # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
+    "stabilityai/sdxl-turbo": {
+        "unet_config": "unet/config.json",
+        "unet": "unet/diffusion_pytorch_model.safetensors",
+        "text_encoder_config": "text_encoder/config.json",
+        "text_encoder_2_config": "text_encoder_2/config.json",
+        "text_encoder": "text_encoder/model.safetensors",
+        "text_encoder_2": "text_encoder_2/model.safetensors",
+        "vae_config": "vae/config.json",
+        "vae": "vae/diffusion_pytorch_model.safetensors",
+        "diffusion_config": "scheduler/scheduler_config.json",
+        "tokenizer_vocab": "tokenizer/vocab.json",
+        "tokenizer_merges": "tokenizer/merges.txt",
+        "tokenizer_2_vocab": "tokenizer_2/vocab.json",
+        "tokenizer_2_merges": "tokenizer_2/merges.txt",
+    },
+    # See https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 for the model details and license
+    "stabilityai/stable-diffusion-xl-base-1.0": {
+        "unet_config": "unet/config.json",
+        "unet": "unet/diffusion_pytorch_model.safetensors",
+        "text_encoder_config": "text_encoder/config.json",
+        "text_encoder_2_config": "text_encoder_2/config.json",
+        "text_encoder": "text_encoder/model.safetensors",
+        "text_encoder_2": "text_encoder_2/model.safetensors",
+        "vae_config": "vae/config.json",
+        "vae": "vae/diffusion_pytorch_model.safetensors",
+        "diffusion_config": "scheduler/scheduler_config.json",
+        "tokenizer_vocab": "tokenizer/vocab.json",
+        "tokenizer_merges": "tokenizer/merges.txt",
+        "tokenizer_2_vocab": "tokenizer_2/vocab.json",
+        "tokenizer_2_merges": "tokenizer_2/merges.txt",
+    },
 }
 
 
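The `_2` suffix is load-bearing: the loaders below assemble table keys with f-strings, so a second encoder or tokenizer is reachable just by passing a different `name`. A minimal sketch of the lookup pattern (assuming `huggingface_hub` is installed and `_MODELS` is in scope):

```python
from huggingface_hub import hf_hub_download

key = "stabilityai/sdxl-turbo"
# f"{name}_config" mirrors the lookup in load_text_encoder below.
for name in ("text_encoder", "text_encoder_2"):
    config_path = hf_hub_download(key, _MODELS[key][f"{name}_config"])
    print(name, "->", config_path)
```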
@@ -205,27 +237,42 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     return model
 
 
-def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
+def load_text_encoder(
+    key: str = _DEFAULT_MODEL, float16: bool = False, name="text_encoder"
+):
     """Load the stable diffusion text encoder from Hugging Face Hub."""
-    _check_key(key, "load_text_encoder")
+    _check_key(key, f"load_{name}")
 
     # Download the config and create the model
-    text_encoder_config = _MODELS[key]["text_encoder_config"]
+    text_encoder_config = _MODELS[key][f"{name}_config"]
     with open(hf_hub_download(key, text_encoder_config)) as f:
         config = json.load(f)
 
-    model = CLIPTextModel(
-        CLIPTextModelConfig(
-            num_layers=config["num_hidden_layers"],
-            model_dims=config["hidden_size"],
-            num_heads=config["num_attention_heads"],
-            max_length=config["max_position_embeddings"],
-            vocab_size=config["vocab_size"],
-        )
-    )
+    # Check whether encoder has projection layer
+    if "CLIPTextModelWithProjection" in config["architectures"]:
+        model = CLIPTextModelWithProjection(
+            CLIPTextModelConfig(
+                num_layers=config["num_hidden_layers"],
+                model_dims=config["hidden_size"],
+                num_heads=config["num_attention_heads"],
+                max_length=config["max_position_embeddings"],
+                vocab_size=config["vocab_size"],
+                projection_dims=config["projection_dim"],
+            )
+        )
+    else:
+        model = CLIPTextModel(
+            CLIPTextModelConfig(
+                num_layers=config["num_hidden_layers"],
+                model_dims=config["hidden_size"],
+                num_heads=config["num_attention_heads"],
+                max_length=config["max_position_embeddings"],
+                vocab_size=config["vocab_size"],
+            )
+        )
 
     # Download the weights and map them into the model
-    text_encoder_weights = _MODELS[key]["text_encoder"]
+    text_encoder_weights = _MODELS[key][name]
     weight_file = hf_hub_download(key, text_encoder_weights)
     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
 
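With the `name` parameter and the `architectures` check in place, loading both SDXL encoders is a two-liner; a usage sketch, assuming the functions live in the package's model_io module as in mlx-examples:

```python
from stable_diffusion.model_io import load_text_encoder

key = "stabilityai/sdxl-turbo"
text_encoder = load_text_encoder(key)                           # plain CLIPTextModel
text_encoder_2 = load_text_encoder(key, name="text_encoder_2")  # projected variant
```

Because the branch keys off `config["architectures"]`, existing SD 2.1-style checkpoints (whose configs name plain `CLIPTextModel`) keep loading exactly as before.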
@@ -277,14 +324,14 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
     )
 
 
-def load_tokenizer(key: str = _DEFAULT_MODEL):
-    _check_key(key, "load_tokenizer")
+def load_tokenizer(key: str = _DEFAULT_MODEL, name="tokenizer"):
+    _check_key(key, f"load_{name}")
 
-    vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
+    vocab_file = hf_hub_download(key, _MODELS[key][f"{name}_vocab"])
     with open(vocab_file, encoding="utf-8") as f:
         vocab = json.load(f)
 
-    merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
+    merges_file = hf_hub_download(key, _MODELS[key][f"{name}_merges"])
     with open(merges_file, encoding="utf-8") as f:
         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
         bpe_merges = [tuple(m.split()) for m in bpe_merges]
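The second tokenizer loads the same way (same import assumption as above):

```python
from stable_diffusion.model_io import load_tokenizer

key = "stabilityai/sdxl-turbo"
tokenizer = load_tokenizer(key)                        # reads tokenizer/* files
tokenizer_2 = load_tokenizer(key, name="tokenizer_2")  # reads tokenizer_2/* files
```

On the merges slice: `[1 : 49152 - 256 - 2 + 1]` skips the header line of merges.txt and keeps 48,894 BPE merges, which together with the 512 byte-level tokens (256 bytes plus their end-of-word variants) and 2 special tokens account for the 49,408-entry CLIP vocabulary.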