mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-05 16:34:34 +08:00
Stable diffusion XL working; "add_embedding" layer not implemented
This commit is contained in:
@@ -136,9 +136,9 @@ class StableDiffusionXL(StableDiffusion):
|
|||||||
)
|
)
|
||||||
if cfg_weight > 1:
|
if cfg_weight > 1:
|
||||||
tokens2 += (
|
tokens2 += (
|
||||||
[self.tokenizer.tokenize(negative_text_2)]
|
[self.tokenizer_2.tokenize(negative_text_2)]
|
||||||
if text_2
|
if text_2
|
||||||
else [self.tokenizer.tokenize(negative_text)]
|
else [self.tokenizer_2.tokenize(negative_text)]
|
||||||
)
|
)
|
||||||
lengths2 = [len(t) for t in tokens2]
|
lengths2 = [len(t) for t in tokens2]
|
||||||
N = max(lengths2)
|
N = max(lengths2)
|
||||||
|
@@ -73,8 +73,8 @@ class CLIPTextModel(nn.Module):
|
|||||||
class CLIPTextModelWithProjection(CLIPTextModel):
|
class CLIPTextModelWithProjection(CLIPTextModel):
|
||||||
def __init__(self, config: CLIPTextModelConfig):
|
def __init__(self, config: CLIPTextModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.projection = nn.Linear(config.model_dims, config.projection_dims)
|
self.text_projection = nn.Linear(config.model_dims, config.projection_dims)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = super().__call__(x)
|
x = super().__call__(x)
|
||||||
return self.projection(x)
|
return self.text_projection(x)
|
||||||
|
@@ -38,7 +38,21 @@ class UNetConfig:
|
|||||||
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||||
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||||
cross_attention_dim: Tuple[int] = (1024,) * 4
|
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||||
norm_num_groups: int = 32
|
norm_num_groups: int = (32,)
|
||||||
|
down_block_types: Tuple[str] = (
|
||||||
|
(
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"CrossAttnDownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
up_block_types: Tuple[str] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
"CrossAttnUpBlock2D",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@@ -10,7 +10,7 @@ from mlx.utils import tree_unflatten
|
|||||||
from safetensors import safe_open as safetensor_open
|
from safetensors import safe_open as safetensor_open
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten, tree_flatten
|
||||||
|
|
||||||
from .clip import CLIPTextModel, CLIPTextModelWithProjection
|
from .clip import CLIPTextModel, CLIPTextModelWithProjection
|
||||||
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
|
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
|
||||||
@@ -195,6 +195,12 @@ def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
|
|||||||
dtype = np.float16 if float16 else np.float32
|
dtype = np.float16 if float16 else np.float32
|
||||||
with safetensor_open(weight_file, framework="numpy") as f:
|
with safetensor_open(weight_file, framework="numpy") as f:
|
||||||
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
|
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
|
||||||
|
# debug
|
||||||
|
bar = tree_flatten(model)
|
||||||
|
missing_weights = [w[0] for w in weights if w[0] not in [b[0] for b in bar]]
|
||||||
|
if missing_weights:
|
||||||
|
print("warning: missing weights")
|
||||||
|
print(missing_weights)
|
||||||
model.update(tree_unflatten(weights))
|
model.update(tree_unflatten(weights))
|
||||||
|
|
||||||
|
|
||||||
@@ -226,6 +232,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
|
|||||||
else config["attention_head_dim"],
|
else config["attention_head_dim"],
|
||||||
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||||
norm_num_groups=config["norm_num_groups"],
|
norm_num_groups=config["norm_num_groups"],
|
||||||
|
down_block_types=config["down_block_types"],
|
||||||
|
up_block_types=config["up_block_types"],
|
||||||
|
transformer_layers_per_block=config["transformer_layers_per_block"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -308,7 +308,8 @@ class UNetModel(nn.Module):
|
|||||||
resnet_groups=config.norm_num_groups,
|
resnet_groups=config.norm_num_groups,
|
||||||
add_downsample=(i < len(config.block_out_channels) - 1),
|
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||||
add_upsample=False,
|
add_upsample=False,
|
||||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
add_cross_attention=config.down_block_types[i]
|
||||||
|
== "CrossAttnDownBlock2D",
|
||||||
)
|
)
|
||||||
for i, (in_channels, out_channels) in enumerate(
|
for i, (in_channels, out_channels) in enumerate(
|
||||||
zip(block_channels, block_channels[1:])
|
zip(block_channels, block_channels[1:])
|
||||||
@@ -357,7 +358,9 @@ class UNetModel(nn.Module):
|
|||||||
resnet_groups=config.norm_num_groups,
|
resnet_groups=config.norm_num_groups,
|
||||||
add_downsample=False,
|
add_downsample=False,
|
||||||
add_upsample=(i > 0),
|
add_upsample=(i > 0),
|
||||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
# add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||||
|
add_cross_attention=list(reversed(config.up_block_types))[i]
|
||||||
|
== "CrossAttnUpBlock2D",
|
||||||
)
|
)
|
||||||
for i, (in_channels, out_channels, prev_out_channels) in reversed(
|
for i, (in_channels, out_channels, prev_out_channels) in reversed(
|
||||||
list(
|
list(
|
||||||
|
@@ -6,7 +6,7 @@ import mlx.core as mx
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from stable_diffusion import StableDiffusion
|
from stable_diffusion import StableDiffusion, StableDiffusionXL
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@@ -22,7 +22,8 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--output", default="out.png")
|
parser.add_argument("--output", default="out.png")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sd = StableDiffusion()
|
sd = StableDiffusionXL("stabilityai/stable-diffusion-xl-base-1.0")
|
||||||
|
# sd = StableDiffusion()
|
||||||
|
|
||||||
# Generate the latent vectors using diffusion
|
# Generate the latent vectors using diffusion
|
||||||
latents = sd.generate_latents(
|
latents = sd.generate_latents(
|
||||||
|
Reference in New Issue
Block a user