Mirror of https://github.com/ml-explore/mlx-examples.git (last synced 2025-09-01 04:14:38 +08:00)
Stable diffusion XL working; "add_embedding" layer not implemented
This commit is contained in:
@@ -136,9 +136,9 @@ class StableDiffusionXL(StableDiffusion):
|
||||
)
|
||||
if cfg_weight > 1:
|
||||
tokens2 += (
|
||||
[self.tokenizer.tokenize(negative_text_2)]
|
||||
[self.tokenizer_2.tokenize(negative_text_2)]
|
||||
if text_2
|
||||
else [self.tokenizer.tokenize(negative_text)]
|
||||
else [self.tokenizer_2.tokenize(negative_text)]
|
||||
)
|
||||
lengths2 = [len(t) for t in tokens2]
|
||||
N = max(lengths2)
|
||||
|
@@ -73,8 +73,8 @@ class CLIPTextModel(nn.Module):
|
||||
class CLIPTextModelWithProjection(CLIPTextModel):
|
||||
def __init__(self, config: CLIPTextModelConfig):
|
||||
super().__init__(config)
|
||||
self.projection = nn.Linear(config.model_dims, config.projection_dims)
|
||||
self.text_projection = nn.Linear(config.model_dims, config.projection_dims)
|
||||
|
||||
def __call__(self, x):
|
||||
x = super().__call__(x)
|
||||
return self.projection(x)
|
||||
return self.text_projection(x)
|
||||
|
@@ -38,7 +38,21 @@ class UNetConfig:
|
||||
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||
norm_num_groups: int = 32
|
||||
norm_num_groups: int = (32,)
|
||||
down_block_types: Tuple[str] = (
|
||||
(
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
)
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@@ -10,7 +10,7 @@ from mlx.utils import tree_unflatten
|
||||
from safetensors import safe_open as safetensor_open
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
from mlx.utils import tree_unflatten, tree_flatten
|
||||
|
||||
from .clip import CLIPTextModel, CLIPTextModelWithProjection
|
||||
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
|
||||
@@ -195,6 +195,12 @@ def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
|
||||
dtype = np.float16 if float16 else np.float32
|
||||
with safetensor_open(weight_file, framework="numpy") as f:
|
||||
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
|
||||
# debug
|
||||
bar = tree_flatten(model)
|
||||
missing_weights = [w[0] for w in weights if w[0] not in [b[0] for b in bar]]
|
||||
if missing_weights:
|
||||
print("warning: missing weights")
|
||||
print(missing_weights)
|
||||
model.update(tree_unflatten(weights))
|
||||
|
||||
|
||||
@@ -226,6 +232,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
else config["attention_head_dim"],
|
||||
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
down_block_types=config["down_block_types"],
|
||||
up_block_types=config["up_block_types"],
|
||||
transformer_layers_per_block=config["transformer_layers_per_block"],
|
||||
)
|
||||
)
|
||||
|
||||
|
@@ -308,7 +308,8 @@ class UNetModel(nn.Module):
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||
add_upsample=False,
|
||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
add_cross_attention=config.down_block_types[i]
|
||||
== "CrossAttnDownBlock2D",
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(
|
||||
zip(block_channels, block_channels[1:])
|
||||
@@ -357,7 +358,9 @@ class UNetModel(nn.Module):
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=(i > 0),
|
||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
# add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
add_cross_attention=list(reversed(config.up_block_types))[i]
|
||||
== "CrossAttnUpBlock2D",
|
||||
)
|
||||
for i, (in_channels, out_channels, prev_out_channels) in reversed(
|
||||
list(
|
||||
|
@@ -6,7 +6,7 @@ import mlx.core as mx
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from stable_diffusion import StableDiffusion
|
||||
from stable_diffusion import StableDiffusion, StableDiffusionXL
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -22,7 +22,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--output", default="out.png")
|
||||
args = parser.parse_args()
|
||||
|
||||
sd = StableDiffusion()
|
||||
sd = StableDiffusionXL("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
# sd = StableDiffusion()
|
||||
|
||||
# Generate the latent vectors using diffusion
|
||||
latents = sd.generate_latents(
|
||||
|
Reference in New Issue
Block a user