nvnsho 2025-07-31 02:36:11 -07:00 committed by GitHub
commit a330537a95
30 changed files with 6217 additions and 0 deletions

video/Wan2.2/.gitignore vendored Normal file
@@ -0,0 +1,36 @@
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
cache
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.2-*
venv_wan/
venv_wan_py310/
.venv/

video/Wan2.2/LICENSE.txt Normal file
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

video/Wan2.2/Makefile Normal file
@@ -0,0 +1,5 @@
.PHONY: format
format:
	isort generate.py wan
	yapf -i -r *.py generate.py wan

video/Wan2.2/README.md Normal file
@@ -0,0 +1,41 @@
# Wan2.2
## Running
| Models | Download Links | Description |
|--------------------|---------------------------------------------------------------------------------------------------------------------------------------------|-------------|
| T2V-A14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | Text-to-Video MoE model, supports 480P & 720P |
| I2V-A14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | Image-to-Video MoE model, supports 480P & 720P |
| TI2V-5B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | High-compression VAE, T2V+I2V, supports 720P |
> 💡Note:
> Currently the MLX port only supports T2V-A14B.
Download models using huggingface-cli:
``` sh
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
```
Download models using modelscope-cli:
``` sh
pip install modelscope
modelscope download Wan-AI/Wan2.2-T2V-A14B --local_dir ./Wan2.2-T2V-A14B
```
### Example
``` sh
python generate.py --task t2v-A14B --size '480*832' --ckpt_dir ./Wan2.2-T2V-A14B --offload_model True --convert_model_dtype --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --frame_num 16 --sample_steps 10
```
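The same pipeline can also be driven from Python. The sketch below is untested and simply mirrors the `t2v` branch of `generate.py` in this commit: the `WanT2V` constructor and the `generate`/`save_video` keyword arguments are taken from that script, while the checkpoint path, prompt, seed, and output filename are placeholders.
``` python
import wan
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
from wan.utils.utils import save_video

cfg = WAN_CONFIGS["t2v-A14B"]
pipe = wan.WanT2V(
    config=cfg,
    checkpoint_dir="./Wan2.2-T2V-A14B",  # placeholder: local checkpoint directory
    convert_model_dtype=True,
)
video = pipe.generate(
    "Two anthropomorphic cats in comfy boxing gear fight on a spotlighted stage.",
    size=SIZE_CONFIGS["480*832"],
    frame_num=17,                        # the --frame_num help text suggests 4n + 1
    shift=cfg.sample_shift,
    sample_solver="unipc",
    sampling_steps=10,
    guide_scale=cfg.sample_guide_scale,
    seed=42,
    offload_model=True,
)
save_video(
    tensor=video[None],                  # add a batch dimension, as generate.py does
    save_file="t2v_demo.mp4",
    fps=cfg.sample_fps,
    nrow=1,
    normalize=True,
    value_range=(-1, 1),
)
```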
## Citation
Thanks to the Wan team for the original PyTorch implementation, which can be cited as follows:
```
@article{wan2025,
title={Wan: Open and Advanced Large-Scale Video Generative Models},
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
journal = {arXiv preprint arXiv:2503.20314},
year={2025}
}
```

Binary file not shown (55 KiB image).

video/Wan2.2/generate.py Normal file
@@ -0,0 +1,388 @@
# MLX implementation of generate.py
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import mlx.core as mx
import mlx.nn as nn
from PIL import Image
import numpy as np
# Note: MLX doesn't have built-in distributed training support like PyTorch
# For distributed training, you would need to implement custom logic or use MPI
import wan # Assuming wan has been converted to MLX
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
# Prompt expanders used when --use_prompt_extend is set (module path follows the upstream Wan repo).
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import save_video, str2bool
EXAMPLE_PROMPT = {
"t2v-A14B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"i2v-A14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
"ti2v-5B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
args.image = EXAMPLE_PROMPT[args.task]["image"]
if args.task == "i2v-A14B":
assert args.image is not None, "Please specify the image path for i2v."
cfg = WAN_CONFIGS[args.task]
if args.sample_steps is None:
args.sample_steps = cfg.sample_steps
if args.sample_shift is None:
args.sample_shift = cfg.sample_shift
if args.sample_guide_scale is None:
args.sample_guide_scale = cfg.sample_guide_scale
if args.frame_num is None:
args.frame_num = cfg.frame_num
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
    assert args.size in SUPPORTED_SIZES[args.task], \
        f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-A14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames of video are generated. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5. (Note: MLX doesn't have built-in FSDP)")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU. (Note: MLX runs on unified memory)")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT. (Note: MLX doesn't have built-in FSDP)")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the video from.")
parser.add_argument(
"--use_prompt_extend",
action="store_true",
default=False,
help="Whether to use prompt extend.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--prompt_extend_target_lang",
type=str,
default="zh",
choices=["zh", "en"],
help="The target language of prompt extend.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=None,
help="Classifier free guidance scale.")
parser.add_argument(
"--convert_model_dtype",
action="store_true",
default=False,
help="Whether to convert model parameters dtype.")
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def generate(args):
# MLX doesn't have built-in distributed training like PyTorch
# For single-device execution, we'll simulate rank 0
rank = 0
world_size = 1
local_rank = 0
# Check for distributed execution environment variables
# Note: Actual distributed implementation would require custom logic
if "RANK" in os.environ:
logging.warning("MLX doesn't have built-in distributed training. Running on single device.")
_init_logging(rank)
if args.offload_model is None:
args.offload_model = False
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
# MLX doesn't support FSDP or distributed training out of the box
if args.t5_fsdp or args.dit_fsdp:
logging.warning("FSDP is not supported in MLX. Ignoring FSDP flags.")
args.t5_fsdp = False
args.dit_fsdp = False
if args.ulysses_size > 1:
logging.warning("Sequence parallel is not supported in MLX single-device mode. Setting ulysses_size to 1.")
args.ulysses_size = 1
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model,
task=args.task,
is_vl=args.image is not None)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
task=args.task,
is_vl=args.image is not None,
device="mlx") # MLX uses unified memory
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
logging.info(f"Input prompt: {args.prompt}")
img = None
if args.image is not None:
img = Image.open(args.image).convert("RGB")
logging.info(f"Input image: {args.image}")
# prompt extend
if args.use_prompt_extend:
logging.info("Extending prompt ...")
prompt_output = prompt_expander(
args.prompt,
image=img,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
        if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
args.prompt = input_prompt
logging.info(f"Extended prompt: {args.prompt}")
if "t2v" in args.task:
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "ti2v" in args.task:
logging.info("Creating WanTI2V pipeline.")
wan_ti2v = wan.WanTI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=None, # MLX uses unified memory
rank=rank,
t5_fsdp=False, # Not supported in MLX
dit_fsdp=False, # Not supported in MLX
use_sp=False, # Not supported in MLX
t5_cpu=False, # MLX uses unified memory
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_ti2v.generate(
args.prompt,
img=img,
size=SIZE_CONFIGS[args.size],
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=None, # MLX uses unified memory
rank=rank,
t5_fsdp=False, # Not supported in MLX
dit_fsdp=False, # Not supported in MLX
use_sp=False, # Not supported in MLX
t5_cpu=False, # MLX uses unified memory
convert_model_dtype=args.convert_model_dtype,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
suffix = '.mp4'
args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
logging.info(f"Saving generated video to {args.save_file}")
# Don't convert to numpy - keep as MLX array
save_video(
tensor=video[None], # Just add batch dimension, keep as MLX array
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1)
)
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)

video/Wan2.2/pyproject.toml Normal file
@@ -0,0 +1,66 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "wan"
version = "2.2.0"
description = "Wan: Open and Advanced Large-Scale Video Generative Models"
authors = [
{ name = "Wan Team", email = "wan.ai@alibabacloud.com" }
]
license = { file = "LICENSE.txt" }
readme = "README.md"
requires-python = ">=3.10,<4.0"
dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"opencv-python>=4.9.0.80",
"diffusers>=0.31.0",
"transformers>=4.49.0",
"tokenizers>=0.20.3",
"accelerate>=1.1.1",
"tqdm",
"imageio",
"easydict",
"ftfy",
"dashscope",
"imageio-ffmpeg",
"flash_attn",
"numpy>=1.23.5,<2"
]
[project.optional-dependencies]
dev = [
"pytest",
"black",
"flake8",
"isort",
"mypy",
"huggingface-hub[cli]"
]
[project.urls]
homepage = "https://wanxai.com"
documentation = "https://github.com/Wan-Video/Wan2.2"
repository = "https://github.com/Wan-Video/Wan2.2"
huggingface = "https://huggingface.co/Wan-AI/"
modelscope = "https://modelscope.cn/organization/Wan-AI"
discord = "https://discord.gg/p5XbdQV7"
[tool.setuptools]
packages = ["wan"]
[tool.setuptools.package-data]
"wan" = ["**/*.py"]
[tool.black]
line-length = 88
[tool.isort]
profile = "black"
[tool.mypy]
strict = true

video/Wan2.2/requirements.txt Normal file
@@ -0,0 +1,15 @@
torch>=2.4.0
mlx
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio[ffmpeg]
easydict
ftfy
dashscope
imageio-ffmpeg
numpy>=1.23.5,<2

video/Wan2.2/wan/__init__.py Normal file
@@ -0,0 +1,2 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .text2video import WanT2V

video/Wan2.2/wan/configs/__init__.py Normal file
@@ -0,0 +1,39 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from .wan_i2v_A14B import i2v_A14B
from .wan_t2v_A14B import t2v_A14B
from .wan_ti2v_5B import ti2v_5B
WAN_CONFIGS = {
't2v-A14B': t2v_A14B,
'i2v-A14B': i2v_A14B,
'ti2v-5B': ti2v_5B,
}
SIZE_CONFIGS = {
'720*1280': (720, 1280),
'1280*720': (1280, 720),
'480*832': (480, 832),
'832*480': (832, 480),
'704*1280': (704, 1280),
'1280*704': (1280, 704)
}
MAX_AREA_CONFIGS = {
'720*1280': 720 * 1280,
'1280*720': 1280 * 720,
'480*832': 480 * 832,
'832*480': 832 * 480,
'704*1280': 704 * 1280,
'1280*704': 1280 * 704,
}
SUPPORTED_SIZES = {
't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
'ti2v-5B': ('704*1280', '1280*704'),
}

video/Wan2.2/wan/configs/shared_config.py Normal file
@@ -0,0 +1,20 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()
# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512
# transformer
wan_shared_cfg.param_dtype = torch.bfloat16
# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# English gloss of the default negative prompt below: "overly saturated colors, overexposed, static, blurry details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards"
wan_shared_cfg.sample_neg_prompt = '色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走'
wan_shared_cfg.frame_num = 81

video/Wan2.2/wan/configs/wan_i2v_A14B.py Normal file
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan I2V A14B ------------------------#
i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)
i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)
# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'
# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise

video/Wan2.2/wan/configs/wan_t2v_A14B.py Normal file
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan T2V A14B ------------------------#
t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)
# t5
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.safetensors'
t2v_A14B.vae_stride = (4, 8, 8)
# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'
# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise

video/Wan2.2/wan/configs/wan_ti2v_5B.py Normal file
@@ -0,0 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan TI2V 5B ------------------------#
ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)
# t5
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
# vae
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)
# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6
# inference
ti2v_5B.sample_fps = 24
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121

video/Wan2.2/wan/modules/__init__.py Normal file
@@ -0,0 +1,17 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model import WanModel, mlx_attention
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE
__all__ = [
    'Wan2_1_VAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'mlx_attention',
]

video/Wan2.2/wan/modules/model.py Normal file
@@ -0,0 +1,660 @@
# MLX implementation of model.py
import math
from typing import List, Tuple, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.astype(mx.float32)
# calculation
arange_vals = mx.arange(half).astype(mx.float32)
div_term = mx.power(10000, -arange_vals / half)
sinusoid = position[:, None] @ div_term[None, :]
x = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
return x
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
positions = mx.arange(max_seq_len).astype(mx.float32)
freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
freqs = 1.0 / mx.power(theta, freqs)
angles = positions[:, None] @ freqs[None, :]
# Store as [max_seq_len, dim//2, 2] where last dimension is [real, imag]
freqs_complex = mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1)
return freqs_complex
def rope_apply(x, grid_sizes, freqs):
n, c = x.shape[2], x.shape[3] // 2
# split freqs based on dimension allocation
split_sizes = [c - 2 * (c // 3), c // 3, c // 3]
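    # The three chunks cover the temporal (f), height (h) and width (w) axes,
    # in the same order the rotary tables are concatenated in WanModel.freqs.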
freqs_splits = []
start = 0
for size in split_sizes:
freqs_splits.append(freqs[:, start:start+size, :])
start += size
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# reshape x_i to complex representation
x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
# precompute frequency multipliers for each dimension
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
freqs_f = mx.tile(freqs_f, (1, h, w, 1, 1)).reshape(f * h * w, -1, 2)
freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
freqs_h = mx.tile(freqs_h, (f, 1, w, 1, 1)).reshape(f * h * w, -1, 2)
freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
freqs_w = mx.tile(freqs_w, (f, h, 1, 1, 1)).reshape(f * h * w, -1, 2)
# Concatenate frequency components
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=1)
freqs_i = freqs_i[:seq_len].reshape(seq_len, 1, c, 2)
# apply rotary embedding (complex multiplication)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
freqs_real = freqs_i[..., 0]
freqs_imag = freqs_i[..., 1]
out_real = x_real * freqs_real - x_imag * freqs_imag
out_imag = x_real * freqs_imag + x_imag * freqs_real
x_i = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, -1)
# Handle remaining sequence
if x.shape[1] > seq_len:
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
output.append(x_i)
return mx.stack(output)
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
"""
Args:
x(Array): Shape [B, L, C]
"""
return self._norm(x) * self.weight
def _norm(self, x):
return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, affine=False):
super().__init__(dims=dim, eps=eps, affine=affine)
def __call__(self, x):
"""
Args:
x(Array): Shape [B, L, C]
"""
return super().__call__(x)
def mlx_attention(
q: mx.array,
k: mx.array,
v: mx.array,
q_lens: Optional[mx.array] = None,
k_lens: Optional[mx.array] = None,
dropout_p: float = 0.,
softmax_scale: Optional[float] = None,
q_scale: Optional[float] = None,
causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
dtype: Optional[type] = None,
) -> mx.array:
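    """Dense scaled dot-product attention written in pure MLX.

    Builds the full [b, n, lq, lk] score matrix, applies the optional window,
    causal, and q/k-length masks additively, and returns [b, lq, n, d].
    Dropout is not implemented.
    """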
# Get shapes
b, lq, n, d = q.shape
_, lk, _, _ = k.shape
# Scale queries if needed
if q_scale is not None:
q = q * q_scale
# Compute attention scores
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]
# Compute attention scores
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]
# Apply softmax scale if provided
if softmax_scale is not None:
scores = scores * softmax_scale
else:
# Default scaling by sqrt(d)
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
# Create attention mask
attn_mask = None
# Apply window size masking if specified
if window_size != (-1, -1):
left_window, right_window = window_size
window_mask = mx.zeros((lq, lk))
for i in range(lq):
start = max(0, i - left_window)
end = min(lk, i + right_window + 1)
window_mask[i, start:end] = 1
attn_mask = window_mask
# Apply causal masking if needed
if causal:
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
if attn_mask is None:
attn_mask = causal_mask
else:
attn_mask = mx.logical_and(attn_mask, causal_mask)
# Apply attention mask if present
if attn_mask is not None:
attn_mask = attn_mask.astype(scores.dtype)
scores = scores * attn_mask + (1 - attn_mask) * -1e4
# Apply attention mask if lengths are provided
if q_lens is not None or k_lens is not None:
if q_lens is not None:
mask = mx.arange(lq)[None, :] < q_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
if k_lens is not None:
mask = mx.arange(lk)[None, :] < k_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
# Apply softmax
max_scores = mx.max(scores, axis=-1, keepdims=True)
scores = scores - max_scores
exp_scores = mx.exp(scores)
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
attn = exp_scores / (sum_exp + 1e-6)
# Apply dropout if needed
if dropout_p > 0 and not deterministic:
raise NotImplementedError("Dropout not implemented in MLX version")
# Compute output
out = mx.matmul(attn, v) # [b, n, lq, d]
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]
return out
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def __call__(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Array): Shape [B, L, C]
seq_lens(Array): Shape [B]
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
"""
b, s, n, d = x.shape[0], x.shape[1], self.num_heads, self.head_dim
# query, key, value function
q = self.norm_q(self.q(x)).reshape(b, s, n, d)
k = self.norm_k(self.k(x)).reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
x = mlx_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.reshape(b, s, -1)
x = self.o(x)
return x
class WanCrossAttention(WanSelfAttention):
def __call__(self, x, context, context_lens):
"""
Args:
x(Array): Shape [B, L1, C]
context(Array): Shape [B, L2, C]
context_lens(Array): Shape [B]
"""
b, n, d = x.shape[0], self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).reshape(b, -1, n, d)
k = self.norm_k(self.k(context)).reshape(b, -1, n, d)
v = self.v(context).reshape(b, -1, n, d)
# compute attention
x = mlx_attention(q, k, v, k_lens=context_lens)
# output
x = x.reshape(b, -1, self.dim)
x = self.o(x)
return x
class WanAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps)
self.norm3 = WanLayerNorm(
dim, eps,
affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(),
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5
def __call__(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
"""
Args:
x(Array): Shape [B, L, C]
e(Array): Shape [B, L1, 6, C]
seq_lens(Array): Shape [B], length of each sequence in batch
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
"""
e = mx.split(self.modulation + e, 6, axis=2)
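        # e[0]/e[1]/e[2]: shift/scale/gate for self-attention;
        # e[3]/e[4]/e[5]: shift/scale/gate for the feed-forward branch below.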
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2),
seq_lens, grid_sizes, freqs)
x = x + y * mx.squeeze(e[2], axis=2)
# cross-attention & ffn function
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(
self.norm2(x) * (1 + mx.squeeze(e[4], axis=2)) + mx.squeeze(e[3], axis=2))
x = x + y * mx.squeeze(e[5], axis=2)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5
def __call__(self, x, e):
"""
Args:
x(Array): Shape [B, L1, C]
e(Array): Shape [B, L1, C]
"""
e = mx.split(self.modulation + mx.expand_dims(e, axis=2), 2, axis=2)
x = self.head(
self.norm(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2))
return x
class WanModel(nn.Module):
"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
assert model_type in ['t2v', 'i2v', 'ti2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim))
self.time_projection = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, dim * 6))
# blocks
self.blocks = [
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps) for _ in range(num_layers)
]
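        # MLX tracks sub-modules stored in plain Python lists, so no ModuleList wrapper is needed.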
# head
self.head = Head(dim, out_dim, patch_size, eps)
# buffers
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
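        # Rotary tables for the three video axes: the temporal axis gets
        # d - 4 * (d // 6) channels, height and width get 2 * (d // 6) each.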
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], axis=1)
# initialize weights
self.init_weights()
def __call__(
self,
x,
t,
context,
seq_len,
y=None,
):
"""
Forward pass through the diffusion model
Args:
x (List[Array]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Array):
Diffusion timesteps tensor of shape [B]
context (List[Array]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
y (List[Array], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Array]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == 'i2v':
assert y is not None
if y is not None:
x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(mx.expand_dims(mx.transpose(u, (1, 2, 3, 0)), axis=0)) for u in x]
grid_sizes = mx.stack(
[mx.array(u.shape[1:4], dtype=mx.int32) for u in x])
x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]
seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
assert seq_lens.max() <= seq_len
# Pad sequences
x_padded = []
for u in x:
pad_len = seq_len - u.shape[1]
if pad_len > 0:
padding = mx.zeros((u.shape[0], pad_len, u.shape[2]))
u = mx.concatenate([u, padding], axis=1)
x_padded.append(u)
x = mx.concatenate(x_padded, axis=0)
# time embeddings
if t.ndim == 1:
t = mx.broadcast_to(t[:, None], (t.shape[0], seq_len))
bt = t.shape[0]
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).reshape(bt, seq_len, -1))
e0 = self.time_projection(e).reshape(bt, seq_len, 6, self.dim)
# context
context_lens = None
context_padded = []
for u in context:
pad_len = self.text_len - u.shape[0]
if pad_len > 0:
padding = mx.zeros((pad_len, u.shape[1]))
u = mx.concatenate([u, padding], axis=0)
context_padded.append(u)
context = self.text_embedding(mx.stack(context_padded))
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def unpatchify(self, x, grid_sizes):
"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Array]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Array):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Array]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for i, v in enumerate(grid_sizes):
v = v.tolist()
seq_len = math.prod(v)
u = x[i, :seq_len].reshape(*v, *self.patch_size, c)
# Rearrange dimensions: (f, h, w, p, q, r, c) -> (c, f*p, h*q, w*r)
u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
u = u.reshape(c, v[0] * self.patch_size[0],
v[1] * self.patch_size[1],
v[2] * self.patch_size[2])
out.append(u)
return out
def init_weights(self):
"""
Initialize model parameters using Xavier initialization.
"""
# Initialize patch embedding
fan_in = self.in_dim * math.prod(self.patch_size)
fan_out = self.dim
std = math.sqrt(2.0 / (fan_in + fan_out))
self.patch_embedding.weight = mx.random.uniform(
low=-std, high=std, shape=self.patch_embedding.weight.shape)
# Initialize text embedding layers with normal distribution
text_layers = list(self.text_embedding.layers)
for i in [0, 2]: # First and third layers
layer = text_layers[i]
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
if hasattr(layer, 'bias') and layer.bias is not None:
layer.bias = mx.zeros(layer.bias.shape)
# Initialize time embedding layers
time_layers = list(self.time_embedding.layers)
for i in [0, 2]: # First and third layers
layer = time_layers[i]
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
if hasattr(layer, 'bias') and layer.bias is not None:
layer.bias = mx.zeros(layer.bias.shape)
# Initialize output head to zeros
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
self.head.head.bias = mx.zeros(self.head.head.bias.shape)

video/Wan2.2/wan/modules/t5.py Normal file
@@ -0,0 +1,616 @@
# MLX implementation for t5.py
import logging
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == mx.float16:
# Use same clamping as PyTorch for consistency
clamp = 65504.0 # max value for float16
return mx.clip(x, -clamp, clamp)
return x
class GELU(nn.Module):
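    """Tanh approximation of GELU; used as the gate activation in T5FeedForward below."""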
def __call__(self, x):
return 0.5 * x * (1.0 + mx.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
# Match PyTorch's approach: convert to float32 for stability
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
x_norm = x_float * mx.rsqrt(variance + self.eps)
# Convert back to original dtype
if x.dtype == mx.float16:
x_norm = x_norm.astype(mx.float16)
return self.weight * x_norm
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
assert dim_attn % num_heads == 0
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, l1, _ = x.shape
_, l2, _ = context.shape
n, c = self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, l1, n, c)
k = self.k(context).reshape(b, l2, n, c)
v = self.v(context).reshape(b, l2, n, c)
# transpose for attention: [B, N, L, C]
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# compute attention (T5 does not use scaling)
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
# add position bias if provided
if pos_bias is not None:
attn = attn + pos_bias
# apply mask
if mask is not None:
if mask.ndim == 2:
# [B, L2] -> [B, 1, 1, L2]
mask = mask[:, None, None, :]
elif mask.ndim == 3:
# [B, L1, L2] -> [B, 1, L1, L2]
mask = mask[:, None, :, :]
# Use very negative value that works well with float16
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
attn = mx.where(mask == 0, min_value, attn)
# softmax and apply attention
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
attn = self.dropout(attn)
# apply attention to values
x = mx.matmul(attn, v) # [B, N, L1, C]
# transpose back and reshape
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
x = x.reshape(b, l1, -1)
# output projection
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.0):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = GELU()
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x):
gate = self.gate_act(self.gate_proj(x))
x = self.fc1(x) * gate
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def __call__(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def __call__(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def __call__(self, lq, lk):
# Create relative position matrix
positions_q = mx.arange(lq)[:, None]
positions_k = mx.arange(lk)[None, :]
rel_pos = positions_k - positions_q
# Apply bucketing
rel_pos = self._relative_position_bucket(rel_pos)
# Get embeddings
rel_pos_embeds = self.embedding(rel_pos)
# Reshape to [1, N, Lq, Lk]
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
return rel_pos_embeds
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
# For large positions, use log scale
rel_pos_large = max_exact + (
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
# Combine small and large position buckets
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.shape
# causal mask
if mask is None:
mask = mx.tril(mx.ones((1, s, s)))
elif mask.ndim == 2:
# Expand mask properly
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def init_mlx_weights(module, key):
"""Initialize weights for T5 model components to match PyTorch initialization"""
    def normal(key, shape, std=1.0):
        return mx.random.normal(shape, key=key) * std
if isinstance(module, T5LayerNorm):
module.weight = mx.ones_like(module.weight)
elif isinstance(module, nn.Embedding):
key = mx.random.split(key, 1)[0]
module.weight = normal(key, module.weight.shape, std=1.0)
elif isinstance(module, T5FeedForward):
# Match PyTorch initialization
key1, key2, key3 = mx.random.split(key, 3)
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
std=module.dim**-0.5)
module.fc1.weight = normal(key2, module.fc1.weight.shape,
std=module.dim**-0.5)
module.fc2.weight = normal(key3, module.fc2.weight.shape,
std=module.dim_ffn**-0.5)
elif isinstance(module, T5Attention):
# Match PyTorch initialization
        key1, key2, key3, key4 = mx.random.split(key, 4)
module.q.weight = normal(key1, module.q.weight.shape,
std=(module.dim * module.dim_attn)**-0.5)
module.k.weight = normal(key2, module.k.weight.shape,
std=module.dim**-0.5)
module.v.weight = normal(key3, module.v.weight.shape,
std=module.dim**-0.5)
module.o.weight = normal(key4, module.o.weight.shape,
std=(module.num_heads * module.dim_attn)**-0.5)
elif isinstance(module, T5RelativeEmbedding):
key = mx.random.split(key, 1)[0]
module.embedding.weight = normal(key, module.embedding.weight.shape,
std=(2 * module.num_buckets * module.num_heads)**-0.5)
elif isinstance(module, nn.Linear):
# Generic linear layer initialization
key = mx.random.split(key, 1)[0]
fan_in = module.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)
        module.weight = mx.random.uniform(low=-bound, high=bound, shape=module.weight.shape, key=key)
return module
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
model = model_cls(**kwargs)
# Initialize weights properly
key = mx.random.key(0)
model = init_mlx_weights(model, key)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.0)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
checkpoint_path=None,
tokenizer_path=None,
):
self.text_len = text_len
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False)
if checkpoint_path:
logging.info(f'loading {checkpoint_path}')
# Load weights - assuming MLX format checkpoint
weights = mx.load(checkpoint_path)
model.update(tree_unflatten(list(weights.items())))
self.model = model
# init tokenizer
from .tokenizers import HuggingfaceTokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
seq_len=text_len,
clean='whitespace')
def __call__(self, texts):
# Handle single string input
if isinstance(texts, str):
texts = [texts]
# Tokenize texts
tokenizer_output = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
# Handle different tokenizer output formats
if isinstance(tokenizer_output, tuple):
ids, mask = tokenizer_output
else:
# Assuming dict output with 'input_ids' and 'attention_mask'
ids = tokenizer_output['input_ids']
mask = tokenizer_output['attention_mask']
# Convert to MLX arrays if not already
if not isinstance(ids, mx.array):
ids = mx.array(ids)
if not isinstance(mask, mx.array):
mask = mx.array(mask)
# Get sequence lengths
seq_lens = mx.sum(mask > 0, axis=1)
# Run encoder
context = self.model(ids, mask)
# Return variable length outputs
# Convert seq_lens to Python list for indexing
if seq_lens.ndim == 0: # Single value
seq_lens_list = [seq_lens.item()]
else:
seq_lens_list = seq_lens.tolist()
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
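# Hedged usage sketch for T5EncoderModel; the checkpoint path is a placeholder
# for an MLX-format UMT5-XXL encoder file:
#   encoder = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='path/to/models_t5_umt5-xxl-enc_mlx.safetensors')
#   context = encoder(["a cat playing piano"])
#   context[0].shape  # (num_tokens, 4096): one variable-length array per prompt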
# Utility function to convert PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
"""Convert PyTorch checkpoint to MLX format"""
import torch
# Load PyTorch checkpoint
pytorch_state = torch.load(pytorch_path, map_location='cpu')
# Convert to numpy then to MLX
mlx_state = {}
    for key, value in pytorch_state.items():
        if isinstance(value, torch.Tensor):
            # Handle the key mapping if needed
            mlx_key = key
            # torch bfloat16 cannot be viewed by numpy, so upcast it first
            if value.dtype == torch.bfloat16:
                value = value.float()
            # Convert tensor to MLX array
            mlx_state[mlx_key] = mx.array(value.numpy())
    # Save MLX checkpoint (a dict of arrays needs mx.savez, not mx.save)
    mx.savez(mlx_path, **mlx_state)
return mlx_state
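# Hedged usage sketch; file names are placeholders (mx.savez writes an .npz file):
#   convert_pytorch_checkpoint('models_t5_umt5-xxl-enc-bf16.pth',
#                              'models_t5_umt5-xxl-enc-bf16_mlx.npz')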

View File

@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ['HuggingfaceTokenizer']
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace('_', ' ')
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans('', '', string.punctuation))
for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans('', '', string.punctuation))
text = text.lower()
text = re.sub(r'\s+', ' ', text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop('return_mask', False)
# arguments
_kwargs = {'return_tensors': 'pt'}
if self.seq_len is not None:
_kwargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.seq_len
})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
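# Hedged usage sketch: fixed-length ids plus attention mask, as consumed by the
# T5 text encoder ('google/umt5-xxl' is the tokenizer used elsewhere in this repo):
#   tok = HuggingfaceTokenizer(name='google/umt5-xxl', seq_len=512, clean='whitespace')
#   ids, mask = tok("a cat playing piano", return_mask=True, add_special_tokens=True)
#   ids.shape, mask.shape  # both (1, 512) torch tensors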

View File

@ -0,0 +1,703 @@
# MLX implementation of vae2_1.py
import logging
from typing import Optional, List, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
__all__ = [
'Wan2_1_VAE',
]
CACHE_T = 2
debug_line = 0
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolution for MLX.
Expects input in BTHWC format (batch, time, height, width, channels).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Padding order: (W, W, H, H, T, 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def __call__(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
padding[4] -= cache_x.shape[1]
# Pad in BTHWC format
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
(padding[0], padding[1]), (0, 0)]
x = mx.pad(x, pad_width)
result = super().__call__(x)
return result
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=False, images=True, bias=False):
super().__init__()
self.channel_first = channel_first
self.images = images
self.scale = dim**0.5
# Just keep as 1D - let broadcasting do its magic
self.gamma = mx.ones((dim,))
self.bias = mx.zeros((dim,)) if bias else 0.
def __call__(self, x):
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
x = x / norm
return x * self.scale * self.gamma + self.bias
class Upsample(nn.Module):
"""
Upsampling layer that matches PyTorch's behavior.
"""
def __init__(self, scale_factor, mode='nearest-exact'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode # mode is now unused, but kept for signature consistency
def __call__(self, x):
scale_h, scale_w = self.scale_factor
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
return out
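# Example of the repeat-based upsample above: with scale_factor=(2., 2.) a
# (B, 4, 4, C) feature map becomes (B, 8, 8, C), i.e. nearest-neighbour doubling:
#   Upsample(scale_factor=(2., 2.))(mx.zeros((1, 4, 4, 3))).shape  # (1, 8, 8, 3)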
class AsymmetricPad(nn.Module):
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
def __init__(self, pad_width: tuple):
super().__init__()
self.pad_width = pad_width
def __call__(self, x):
return mx.pad(x, self.pad_width)
# Resample applies the nearest-exact style Upsample defined above
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        # Asymmetric (0, 1) spatial padding before the stride-2 conv matches PyTorch's downsampling
elif mode == 'downsample2d':
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
elif mode == 'downsample3d':
# The spatial downsampling part uses the same logic
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# The __call__ method logic remains unchanged from your original code
b, t, h, w, c = x.shape
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
cache_x = mx.concatenate([
mx.zeros_like(cache_x), cache_x
], axis=1)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, t, h, w, 2, c)
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
x = x.reshape(b, t * 2, h, w, c)
t = x.shape[1]
x = x.reshape(b * t, h, w, c)
x = self.resample(x)
_, h_new, w_new, c_new = x.shape
x = x.reshape(b, t, h_new, w_new, c_new)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_idx[0] += 1
else:
cache_x = x[:, -1:, :, :, :]
x = self.time_conv(
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
CausalConv3d(out_dim, out_dim, 3, padding=1)
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for i, layer in enumerate(self.residual.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
self.proj.weight = mx.zeros_like(self.proj.weight)
def __call__(self, x):
# x is in BTHWC format
identity = x
b, t, h, w, c = x.shape
x = x.reshape(b * t, h, w, c) # Combine batch and time
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
qkv = qkv.reshape(b * t, h * w, 3 * c)
q, k, v = mx.split(qkv, 3, axis=-1)
# Reshape for attention
q = q.reshape(b * t, h * w, c)
k = k.reshape(b * t, h * w, c)
v = v.reshape(b * t, h * w, c)
# Scaled dot product attention
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
scores = (q @ k.transpose(0, 2, 1)) * scale
weights = mx.softmax(scores, axis=-1)
x = weights @ v
x = x.reshape(b * t, h, w, c)
# output
x = self.proj(x)
x = x.reshape(b, t, h, w, c)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[-1], dims[-1], dropout),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1], dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], z_dim, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for i, layer in enumerate(self.downsamples.layers):
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
            if i in (1, 2, 3):
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], 3, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples.layers:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for name, module in model.named_modules():
if isinstance(module, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x, scale):
# x is in BTHWC format
self.clear_cache()
## cache
t = x.shape[1]
iter_ = 1 + (t - 1) // 4
## Split encode input x by time into 1, 4, 4, 4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :1, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = mx.concatenate([out, out_], axis=1)
z = self.conv1(out)
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
if isinstance(scale[0], mx.array):
# Reshape scale for broadcasting in BTHWC format
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
mu = (mu - scale_mean) * scale_std
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu, log_var
def decode(self, z, scale):
# z is in BTHWC format
self.clear_cache()
if isinstance(scale[0], mx.array):
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
z = z / scale_std + scale_mean
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[1]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = mx.concatenate([out, out_], axis=1)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = mx.exp(0.5 * log_var)
eps = mx.random.normal(std.shape)
return eps * std + mu
def __call__(self, x):
mu, log_var = self.encode(x, self.scale)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z, self.scale)
return x_recon, mu, log_var
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs, self.scale)
if deterministic:
return mu
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
return mu + std * mx.random.normal(std.shape)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
# params
cfg = dict(
dim=96,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0)
cfg.update(**kwargs)
# init model
model = WanVAE_(**cfg)
# load checkpoint
if pretrained_path:
logging.info(f'loading {pretrained_path}')
weights = mx.load(pretrained_path)
model.update(tree_unflatten(list(weights.items())))
return model
class Wan2_1_VAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=mx.float32):
self.dtype = dtype
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = mx.array(mean, dtype=dtype)
self.std = mx.array(std, dtype=dtype)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
Returns: List of encoded videos in [C, T, H, W] format.
"""
encoded = []
for video in videos:
# Convert CTHW -> BTHWC
x = mx.expand_dims(video, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Encode
z = self.model.encode(x, self.scale)[0] # Get mu only
# Convert back BTHWC -> CTHW and remove batch dimension
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
z = z.squeeze(0) # Remove batch dimension -> CTHW
encoded.append(z.astype(mx.float32))
return encoded
def decode(self, zs):
"""
zs: A list of latent codes each with shape [C, T, H, W].
Returns: List of decoded videos in [C, T, H, W] format.
"""
decoded = []
for z in zs:
# Convert CTHW -> BTHWC
x = mx.expand_dims(z, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Decode
x = self.model.decode(x, self.scale)
# Convert back BTHWC -> CTHW and remove batch dimension
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
x = x.squeeze(0) # Remove batch dimension -> CTHW
# Clamp values
x = mx.clip(x, -1, 1)
decoded.append(x.astype(mx.float32))
return decoded
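# Hedged usage sketch for the VAE wrapper; the checkpoint path is a placeholder
# and the shapes assume the default temporal (x4) and spatial (x8) strides:
#   vae = Wan2_1_VAE(vae_pth='path/to/Wan2.1_VAE_mlx.safetensors')
#   video = mx.zeros((3, 9, 64, 64))   # [C, T, H, W] with T = 4n + 1
#   z = vae.encode([video])[0]         # (16, 3, 8, 8): T -> 1 + (T - 1) // 4
#   recon = vae.decode([z])[0]         # (3, 9, 64, 64), clipped to [-1, 1]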

View File

@ -0,0 +1,265 @@
import json
from typing import Optional, List, Tuple
import mlx.core as mx
from mlx.utils import tree_unflatten
from safetensors import safe_open
import torch
def check_safetensors_dtypes(safetensors_path: str):
"""
Check what dtypes are in the safetensors file.
Useful for debugging dtype issues.
"""
print(f"🔍 Checking dtypes in: {safetensors_path}")
dtype_counts = {}
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
dtype_str = str(tensor.dtype)
if dtype_str not in dtype_counts:
dtype_counts[dtype_str] = []
dtype_counts[dtype_str].append(key)
print("📊 Dtype summary:")
for dtype, keys in dtype_counts.items():
print(f" {dtype}: {len(keys)} parameters")
if dtype == "torch.bfloat16":
print(f" ⚠️ BFloat16 detected - will convert to float32")
print(f" Examples: {keys[:3]}")
return dtype_counts
def convert_tensor_dtype(tensor: torch.Tensor) -> torch.Tensor:
"""
Convert tensor to MLX-compatible dtype.
"""
if tensor.dtype == torch.bfloat16:
# Convert BFloat16 to float32
return tensor.float()
elif tensor.dtype == torch.float64:
# Convert float64 to float32 for efficiency
return tensor.float()
else:
# Keep other dtypes as-is
return tensor
def map_t5_encoder_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
"""
Map T5 encoder weights from PyTorch format to MLX format.
Following the pattern used in MLX Stable Diffusion.
Args:
key: Parameter name from PyTorch model
value: Parameter tensor
Returns:
List of (key, value) tuples for MLX model
"""
# Handle the main structural difference: FFN gate layer
if ".ffn.gate.0.weight" in key:
# PyTorch has Sequential(Linear, GELU) but MLX has separate gate_proj + gate_act
key = key.replace(".ffn.gate.0.weight", ".ffn.gate_proj.weight")
return [(key, value)]
elif ".ffn.gate.0.bias" in key:
# Handle bias if it exists
key = key.replace(".ffn.gate.0.bias", ".ffn.gate_proj.bias")
return [(key, value)]
elif ".ffn.gate.1" in key:
# Skip GELU activation parameters - MLX handles this separately
print(f"Skipping GELU parameter: {key}")
return []
# Handle any other potential FFN mappings
elif ".ffn.fc1.weight" in key:
return [(key, value)]
elif ".ffn.fc2.weight" in key:
return [(key, value)]
# Handle attention layers (should be direct mapping)
elif ".attn.q.weight" in key:
return [(key, value)]
elif ".attn.k.weight" in key:
return [(key, value)]
elif ".attn.v.weight" in key:
return [(key, value)]
elif ".attn.o.weight" in key:
return [(key, value)]
# Handle embeddings and norms (direct mapping)
elif "token_embedding.weight" in key:
return [(key, value)]
elif "pos_embedding.embedding.weight" in key:
return [(key, value)]
elif "norm1.weight" in key or "norm2.weight" in key or "norm.weight" in key:
return [(key, value)]
# Default: direct mapping for any other parameters
else:
return [(key, value)]
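# Example of the only structural remap above: the PyTorch Sequential gate becomes
# the MLX gate_proj Linear, and the parameter-free GELU entry is dropped:
#   map_t5_encoder_weights('blocks.0.ffn.gate.0.weight', w)
#       -> [('blocks.0.ffn.gate_proj.weight', w)]
#   map_t5_encoder_weights('blocks.0.ffn.gate.1.weight', w)
#       -> []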
def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
"""Flatten nested list of parameter tuples"""
return [(k, v) for p in params for (k, v) in p]
def _load_safetensor_weights(
mapper_func,
model,
weight_file: str,
float16: bool = False
):
"""
Load safetensor weights using the mapping function.
Based on MLX SD pattern.
"""
dtype = mx.float16 if float16 else mx.float32
# Load weights from safetensors file
weights = {}
with safe_open(weight_file, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
# Handle BFloat16 - convert to float32 first
if tensor.dtype == torch.bfloat16:
print(f"Converting BFloat16 to float32 for: {key}")
tensor = tensor.float() # Convert to float32
weights[key] = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping function
mapped_weights = _flatten([mapper_func(k, v) for k, v in weights.items()])
# Update model with mapped weights
model.update(tree_unflatten(mapped_weights))
return model
def load_t5_encoder_from_safetensors(
safetensors_path: str,
model, # Your MLX T5Encoder instance
float16: bool = False
):
"""
Load T5 encoder weights from safetensors file into MLX model.
Args:
safetensors_path: Path to the safetensors file
model: Your MLX T5Encoder model instance
float16: Whether to use float16 precision
Returns:
Model with loaded weights
"""
print(f"Loading T5 encoder weights from: {safetensors_path}")
# Load and map weights
model = _load_safetensor_weights(
map_t5_encoder_weights,
model,
safetensors_path,
float16
)
print("T5 encoder weights loaded successfully!")
return model
def debug_weight_mapping(safetensors_path: str, float16: bool = False):
"""
Debug function to see how weights are being mapped.
Useful for troubleshooting conversion issues.
"""
dtype = mx.float16 if float16 else mx.float32
print("=== T5 Weight Mapping Debug ===")
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
# Handle BFloat16
original_dtype = tensor.dtype
if tensor.dtype == torch.bfloat16:
print(f"Converting BFloat16 to float32 for: {key}")
tensor = tensor.float()
value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_t5_encoder_weights(key, value)
if len(mapped) == 0:
print(f"SKIPPED: {key} ({original_dtype}) -> (no mapping)")
elif len(mapped) == 1:
new_key, new_value = mapped[0]
if new_key == key:
print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]")
else:
print(f"MAPPED: {key} ({original_dtype}) -> {new_key} [{tensor.shape}]")
else:
print(f"SPLIT: {key} ({original_dtype}) -> {len(mapped)} parameters")
for new_key, new_value in mapped:
print(f" -> {new_key} [{new_value.shape}]")
def convert_safetensors_to_mlx_weights(
safetensors_path: str,
output_path: str,
float16: bool = False
):
"""
    Convert a PyTorch safetensors file to an MLX safetensors weights file.
    Args:
        safetensors_path: Input safetensors file
        output_path: Output MLX weights file (.safetensors)
float16: Whether to use float16 precision
"""
dtype = mx.float16 if float16 else mx.float32
print(f"Converting safetensors to MLX format...")
print(f"Input: {safetensors_path}")
print(f"Output: {output_path}")
print(f"Target dtype: {dtype}")
# Load and convert weights
weights = {}
bfloat16_count = 0
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
            # Handle BFloat16 (numpy cannot represent it, so upcast to float32 first)
            if tensor.dtype == torch.bfloat16:
                bfloat16_count += 1
                tensor = tensor.float()
            value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_t5_encoder_weights(key, value)
for new_key, new_value in mapped:
weights[new_key] = new_value
if bfloat16_count > 0:
print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to float32")
# Save as MLX format
print(f"Saving {len(weights)} parameters to: {output_path}")
mx.save_safetensors(output_path, weights)
return weights
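# Hedged usage sketch; input/output file names are placeholders:
#   convert_safetensors_to_mlx_weights(
#       'models_t5_umt5-xxl-enc-bf16.safetensors',
#       'models_t5_umt5-xxl-enc-bf16_mlx.safetensors',
#       float16=False)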

View File

@ -0,0 +1,233 @@
import os
import torch
from safetensors.torch import save_file
from pathlib import Path
import json
from wan.modules.t5 import T5Model
def convert_pickle_to_safetensors(
pickle_path: str,
safetensors_path: str,
model_class=None,
model_kwargs=None,
load_method: str = "weights_only" # Changed default to weights_only
):
"""Convert PyTorch pickle file to safetensors format."""
print(f"Loading PyTorch weights from: {pickle_path}")
# Try multiple loading methods in order of preference
methods_to_try = [load_method, "weights_only", "state_dict", "full_model"]
for method in methods_to_try:
try:
if method == "weights_only":
state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
elif method == "state_dict":
checkpoint = torch.load(pickle_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, dict) and 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
elif method == "full_model":
model = torch.load(pickle_path, map_location='cpu')
if hasattr(model, 'state_dict'):
state_dict = model.state_dict()
else:
state_dict = model
print(f"✅ Successfully loaded with method: {method}")
break
except Exception as e:
print(f"❌ Method {method} failed: {e}")
continue
else:
raise RuntimeError(f"All loading methods failed for {pickle_path}")
# Clean up the state dict
state_dict = clean_state_dict(state_dict)
print(f"Found {len(state_dict)} parameters")
# Convert BF16 to FP32 if needed
for key, tensor in state_dict.items():
if tensor.dtype == torch.bfloat16:
state_dict[key] = tensor.to(torch.float32)
print(f"Converted {key} from bfloat16 to float32")
# Save as safetensors
print(f"Saving to safetensors: {safetensors_path}")
os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
save_file(state_dict, safetensors_path)
print("✅ T5 conversion complete!")
return state_dict
def clean_state_dict(state_dict):
"""
Clean up state dict by removing unwanted prefixes or keys.
"""
cleaned = {}
for key, value in state_dict.items():
# Remove common prefixes that might interfere
clean_key = key
if clean_key.startswith('module.'):
clean_key = clean_key[7:]
if clean_key.startswith('model.'):
clean_key = clean_key[6:]
cleaned[clean_key] = value
if clean_key != key:
print(f"Cleaned key: {key} -> {clean_key}")
return cleaned
def load_with_your_torch_model(pickle_path: str, model_class, **model_kwargs):
"""
Load pickle weights into your specific PyTorch T5 model implementation.
Args:
pickle_path: Path to pickle file
model_class: Your T5Encoder class
**model_kwargs: Arguments for your model constructor
"""
print("Method 1: Loading into your PyTorch T5 model")
# Initialize your model
model = model_class(**model_kwargs)
# Load checkpoint
checkpoint = torch.load(pickle_path, map_location='cpu')
# Handle different checkpoint formats
if isinstance(checkpoint, dict):
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
# Assume the dict IS the state dict
state_dict = checkpoint
else:
# Assume it's a model object
state_dict = checkpoint.state_dict()
# Clean and load
state_dict = clean_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore missing keys
return model, state_dict
def explore_pickle_file(pickle_path: str):
"""
Explore the contents of a pickle file to understand its structure.
"""
print(f"🔍 Exploring pickle file: {pickle_path}")
try:
# Try loading with weights_only first (safer)
print("\n--- Trying weights_only=True ---")
try:
data = torch.load(pickle_path, map_location='cpu', weights_only=True)
print(f"✅ Loaded with weights_only=True")
print(f"Type: {type(data)}")
if isinstance(data, dict):
print(f"Dictionary with {len(data)} keys:")
for i, key in enumerate(data.keys()):
print(f" {key}: {type(data[key])}")
if hasattr(data[key], 'shape'):
print(f" Shape: {data[key].shape}")
if i >= 9: # Show first 10 keys
break
except Exception as e:
print(f"❌ weights_only=True failed: {e}")
# Try regular loading
print("\n--- Trying regular loading ---")
data = torch.load(pickle_path, map_location='cpu')
print(f"✅ Loaded successfully")
print(f"Type: {type(data)}")
if hasattr(data, 'state_dict'):
print("📋 Found state_dict method")
state_dict = data.state_dict()
print(f"State dict has {len(state_dict)} parameters")
elif isinstance(data, dict):
print(f"📋 Dictionary with keys: {list(data.keys())}")
# Check for common checkpoint keys
if 'state_dict' in data:
print("Found 'state_dict' key")
print(f"state_dict has {len(data['state_dict'])} parameters")
elif 'model' in data:
print("Found 'model' key")
print(f"model has {len(data['model'])} parameters")
except Exception as e:
print(f"❌ Failed to load: {e}")
def full_conversion_pipeline(
pickle_path: str,
safetensors_path: str,
torch_model_class=None,
model_kwargs=None
):
"""
Complete pipeline: pickle -> safetensors -> ready for MLX conversion
"""
print("🚀 Starting full conversion pipeline")
print("="*50)
# Step 1: Explore the pickle file
print("Step 1: Exploring pickle file structure")
explore_pickle_file(pickle_path)
# Step 2: Convert to safetensors
print(f"\nStep 2: Converting to safetensors")
# Try different loading methods
for method in ["weights_only", "state_dict", "full_model"]:
try:
print(f"\nTrying load method: {method}")
state_dict = convert_pickle_to_safetensors(
pickle_path,
safetensors_path,
model_class=torch_model_class,
model_kwargs=model_kwargs,
load_method=method
)
print(f"✅ Success with method: {method}")
break
except Exception as e:
print(f"❌ Method {method} failed: {e}")
continue
else:
print("❌ All methods failed!")
return None
print(f"\n🎉 Conversion complete!")
print(f"Safetensors file saved to: {safetensors_path}")
print(f"Ready for MLX conversion using the previous script!")
return state_dict
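# Hedged usage sketch; paths are placeholders for wherever the original PyTorch
# T5 checkpoint lives on disk:
#   full_conversion_pipeline(
#       'checkpoints/models_t5_umt5-xxl-enc-bf16.pth',
#       'checkpoints/models_t5_umt5-xxl-enc-bf16.safetensors')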

View File

@ -0,0 +1,401 @@
# MLX implementation of text2video.py
import gc
import glob
import logging
import math
import os
import random
import sys
from contextlib import contextmanager
from functools import partial
from typing import Optional, Tuple, List, Dict, Any, Union
import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .wan_model_io import (
    convert_multiple_wan_2_2_safetensors_to_mlx,
    convert_wan_2_2_safetensors_to_mlx,
    load_wan_2_2_from_safetensors,
)
class WanT2V:
def __init__(
self,
config,
checkpoint_dir: str,
device_id: int = 0,
convert_model_dtype: bool = False,
):
r"""
Initializes the Wan text-to-video generation model components for MLX.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Device id (kept for compatibility, MLX handles device automatically)
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
"""
self.config = config
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.num_train_timesteps = config.num_train_timesteps
self.boundary = config.boundary
# Convert PyTorch dtype to MLX dtype
if str(config.param_dtype) == 'torch.bfloat16':
self.param_dtype = mx.bfloat16
elif str(config.param_dtype) == 'torch.float16':
self.param_dtype = mx.float16
elif str(config.param_dtype) == 'torch.float32':
self.param_dtype = mx.float32
else:
self.param_dtype = mx.float32 # default
# Initialize T5 text encoder
print(f"checkpoint_dir is: {checkpoint_dir}")
t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint)
mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors')
if not os.path.exists(mlx_t5_path):
# Check if it's a .pth file that needs conversion
pth_path = t5_checkpoint_path.replace('.safetensors', '.pth')
if os.path.exists(pth_path):
logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}")
from .t5_torch_to_sf import convert_pickle_to_safetensors
convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only")
# Convert torch safetensors to MLX safetensors
from .t5_model_io import convert_safetensors_to_mlx_weights
convert_safetensors_to_mlx_weights(t5_checkpoint_path, mlx_t5_path, float16=(self.param_dtype == mx.float16))
else:
raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}")
t5_checkpoint_path = mlx_t5_path # Use the MLX version
logging.info(f"Loading T5 text encoder... from {t5_checkpoint_path}")
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
checkpoint_path=t5_checkpoint_path,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer))
# Initialize VAE - with automatic conversion
vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint)
if not os.path.exists(vae_path):
# Check for PyTorch VAE file to convert
pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth')
if not os.path.exists(pth_vae_path):
# Try alternative naming
pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth')
if os.path.exists(pth_vae_path):
logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}")
from .vae_model_io import convert_pytorch_to_mlx
convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16))
else:
raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}")
logging.info("Loading VAE...")
self.vae = Wan2_1_VAE(vae_pth=vae_path)
# Load low and high noise models
logging.info(f"Creating WanModel from {checkpoint_dir}")
# Helper function to load model with automatic conversion
def load_model_with_conversion(checkpoint_dir, subfolder, config, param_dtype):
"""Load model with automatic PyTorch to MLX conversion if needed."""
# Look for existing MLX files
mlx_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model_mlx.safetensors")
mlx_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_mlx_model*.safetensors")
mlx_files = glob.glob(mlx_pattern)
# If no MLX files, convert PyTorch files
if not os.path.exists(mlx_single) and not mlx_files:
pytorch_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model.safetensors")
pytorch_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model-*.safetensors")
pytorch_files = glob.glob(pytorch_pattern)
if os.path.exists(pytorch_single):
logging.info(f"Converting PyTorch model to MLX: {pytorch_single}")
convert_wan_2_2_safetensors_to_mlx(
pytorch_single,
mlx_single,
float16=(param_dtype == mx.float16)
)
elif pytorch_files:
logging.info(f"Converting {len(pytorch_files)} PyTorch files to MLX")
convert_multiple_wan_2_2_safetensors_to_mlx(
os.path.join(checkpoint_dir, subfolder),
float16=(param_dtype == mx.float16)
)
mlx_files = glob.glob(mlx_pattern) # Update file list
else:
raise FileNotFoundError(f"No model files found in {os.path.join(checkpoint_dir, subfolder)}")
# Create model
model = WanModel(
model_type='t2v',
patch_size=config.patch_size,
text_len=config.text_len,
in_dim=16,
dim=config.dim,
ffn_dim=config.ffn_dim,
freq_dim=config.freq_dim,
text_dim=4096,
out_dim=16,
num_heads=config.num_heads,
num_layers=config.num_layers,
window_size=getattr(config, 'window_size', (-1, -1)),
qk_norm=getattr(config, 'qk_norm', True),
cross_attn_norm=getattr(config, 'cross_attn_norm', True),
eps=getattr(config, 'eps', 1e-6)
)
# Load weights
if os.path.exists(mlx_single):
logging.info(f"Loading single MLX file: {mlx_single}")
model = load_wan_2_2_from_safetensors(mlx_single, model, float16=(param_dtype == mx.float16))
else:
logging.info(f"Loading multiple MLX files from: {os.path.join(checkpoint_dir, subfolder)}")
model = load_wan_2_2_from_safetensors(
os.path.join(checkpoint_dir, subfolder),
model,
float16=(param_dtype == mx.float16)
)
return model
# Load both models
logging.info(f"Creating WanModel from {checkpoint_dir}")
logging.info("Loading low noise model")
self.low_noise_model = load_model_with_conversion(
checkpoint_dir,
config.low_noise_checkpoint,
self.config,
self.param_dtype
)
self.low_noise_model = self._configure_model(self.low_noise_model, convert_model_dtype)
logging.info("Loading high noise model")
self.high_noise_model = load_model_with_conversion(
checkpoint_dir,
config.high_noise_checkpoint,
self.config,
self.param_dtype
)
self.high_noise_model = self._configure_model(self.high_noise_model, convert_model_dtype)
self.sp_size = 1 # No sequence parallel in single device
self.sample_neg_prompt = config.sample_neg_prompt
def _configure_model(self, model: nn.Module, convert_model_dtype: bool) -> nn.Module:
"""
Configures a model object for MLX.
Args:
model (nn.Module):
The model instance to configure.
convert_model_dtype (`bool`):
Convert model parameters dtype to 'config.param_dtype'.
Returns:
nn.Module:
The configured model.
"""
model.eval()
if convert_model_dtype:
# In MLX, we would need to manually convert parameters
# This would be implemented in the actual model class
pass
return model
def _prepare_model_for_timestep(self, t, boundary, offload_model):
"""
Prepares and returns the required model for the current timestep.
"""
if t.item() >= boundary:
required_model_name = 'high_noise_model'
offload_model_name = 'low_noise_model'
else:
required_model_name = 'low_noise_model'
offload_model_name = 'high_noise_model'
# MLX doesn't need the CPU offloading logic, just return the right model
return getattr(self, required_model_name)
def generate(
self,
input_prompt: str,
size: Tuple[int, int] = (1280, 720),
frame_num: int = 81,
shift: float = 5.0,
sample_solver: str = 'unipc',
sampling_steps: int = 50,
guide_scale: Union[float, Tuple[float, float]] = 5.0,
n_prompt: str = "",
seed: int = -1,
offload_model: bool = True
) -> Optional[mx.array]:
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (`tuple[int]`, *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps.
            guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
Classifier-free guidance scale.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion.
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
                Kept for API compatibility; in the MLX version it only triggers a final garbage collection.
Returns:
mx.array:
Generated video frames tensor. Dimensions: (C, N, H, W)
"""
# Preprocess
guide_scale = (guide_scale, guide_scale) if isinstance(
guide_scale, float) else guide_scale
F = frame_num
target_shape = (
self.vae.model.z_dim,
(F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2]
)
seq_len = math.ceil(
(target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size
) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# Set random seed
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
mx.random.seed(seed)
# Encode text prompts
context = self.text_encoder([input_prompt])
context_null = self.text_encoder([n_prompt])
# Generate initial noise
noise = [
mx.random.normal(
shape=target_shape,
dtype=mx.float32
)
]
# Set boundary
boundary = self.boundary * self.num_train_timesteps
# Initialize scheduler
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(
sampling_steps, shift=shift
)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False
)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
sigmas=sampling_sigmas
)
else:
raise NotImplementedError("Unsupported solver.")
# Sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
mx.eval(latents)
# Denoising loop
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = mx.array([t])
# Select model based on timestep
model = self._prepare_model_for_timestep(
t, boundary, offload_model
)
sample_guide_scale = guide_scale[1] if t.item() >= boundary else guide_scale[0]
# Model predictions
noise_pred_cond = model(
latent_model_input, t=timestep, **arg_c
)[0]
mx.eval(noise_pred_cond) # Force evaluation
noise_pred_uncond = model(
latent_model_input, t=timestep, **arg_null
)[0]
mx.eval(noise_pred_uncond) # Force evaluation
# Classifier-free guidance
noise_pred = noise_pred_uncond + sample_guide_scale * (
noise_pred_cond - noise_pred_uncond
)
mx.eval(noise_pred) # Force evaluation
# Scheduler step
temp_x0 = sample_scheduler.step(
mx.expand_dims(noise_pred, axis=0),
t,
mx.expand_dims(latents[0], axis=0),
return_dict=False
)[0]
latents = [mx.squeeze(temp_x0, axis=0)]
mx.eval(latents)
# Decode final latents
x0 = latents
videos = self.vae.decode(x0)
# Cleanup
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
return videos[0]
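# Hedged usage sketch; `cfg` stands for the EasyDict config described in the
# __init__ docstring and the checkpoint directory is a placeholder:
#   pipe = WanT2V(cfg, checkpoint_dir='Wan2.2-T2V-A14B')
#   video = pipe.generate("a corgi running on the beach",
#                         size=(1280, 720), frame_num=81,
#                         sample_solver='unipc', sampling_steps=50)
#   video.shape  # (C, N, H, W), values roughly in [-1, 1]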

View File

@ -0,0 +1,12 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
__all__ = [
    'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
]

View File

@ -0,0 +1,562 @@
import math
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
def get_sampling_sigmas(sampling_steps, shift):
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
sigma = (shift * sigma / (1 + (shift - 1) * sigma))
return sigma
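# Example: get_sampling_sigmas(4, shift=5.0) warps the raw sigmas
# [1.0, 0.75, 0.5, 0.25] toward 1, giving roughly [1.0, 0.9375, 0.8333, 0.625].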
def retrieve_timesteps(
scheduler,
num_inference_steps=None,
device=None,
timesteps=None,
sigmas=None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None:
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class SchedulerOutput:
"""Output class for scheduler step results."""
def __init__(self, prev_sample: mx.array):
self.prev_sample = prev_sample
class FlowDPMSolverMultistepScheduler:
"""
MLX implementation of FlowDPMSolverMultistepScheduler.
A fast dedicated high-order solver for diffusion ODEs.
"""
order = 1
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting: bool = False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero",
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
invert_sigmas: bool = False,
):
# Store configuration
self.config = {
'num_train_timesteps': num_train_timesteps,
'solver_order': solver_order,
'prediction_type': prediction_type,
'shift': shift,
'use_dynamic_shifting': use_dynamic_shifting,
'thresholding': thresholding,
'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
'sample_max_value': sample_max_value,
'algorithm_type': algorithm_type,
'solver_type': solver_type,
'lower_order_final': lower_order_final,
'euler_at_final': euler_at_final,
'final_sigmas_type': final_sigmas_type,
'lambda_min_clipped': lambda_min_clipped,
'variance_type': variance_type,
'invert_sigmas': invert_sigmas,
}
# Validate algorithm type
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.config['algorithm_type'] = "dpmsolver++"
else:
raise NotImplementedError(f"{algorithm_type} is not implemented")
# Validate solver type
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.config['solver_type'] = "midpoint"
else:
raise NotImplementedError(f"{solver_type} is not implemented")
# Initialize scheduling
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = mx.array(sigmas, dtype=mx.float32)
if not use_dynamic_shifting:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigma_min = float(self.sigmas[-1])
self.sigma_max = float(self.sigmas[0])
@property
def step_index(self):
return self._step_index
@property
def begin_index(self):
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, None] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""Sets the discrete timesteps used for the diffusion chain."""
if self.config['use_dynamic_shifting'] and mu is None:
raise ValueError(
"you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
if self.config['use_dynamic_shifting']:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
if shift is None:
shift = self.config['shift']
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
if self.config['final_sigmas_type'] == "sigma_min":
sigma_last = self.sigma_min
elif self.config['final_sigmas_type'] == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
)
timesteps = sigmas * self.config['num_train_timesteps']
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(timesteps, dtype=mx.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [None] * self.config['solver_order']
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
def _threshold_sample(self, sample: mx.array) -> mx.array:
"""Dynamic thresholding method."""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
# Flatten sample for quantile calculation
sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = mx.abs(sample_flat)
# Compute quantile
s = mx.quantile(
abs_sample,
self.config['dynamic_thresholding_ratio'],
axis=1,
keepdims=True
)
s = mx.clip(s, 1, self.config['sample_max_value'])
# Threshold and normalize
sample_flat = mx.clip(sample_flat, -s, s) / s
sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
return sample.astype(dtype)
def _sigma_to_t(self, sigma):
return sigma * self.config['num_train_timesteps']
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
def time_shift(self, mu: float, sigma: float, t: mx.array):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
def convert_model_output(
self,
model_output: mx.array,
sample: mx.array,
**kwargs,
) -> mx.array:
"""Convert model output to the corresponding type the algorithm needs."""
# DPM-Solver++ needs to solve an integral of the data prediction model
if self.config['algorithm_type'] in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be "
f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
)
if self.config['thresholding']:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model
elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be "
f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
)
if self.config['thresholding']:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def dpm_solver_first_order_update(
self,
model_output: mx.array,
sample: mx.array,
noise: Optional[mx.array] = None,
**kwargs,
) -> mx.array:
"""One step for the first-order DPMSolver (equivalent to DDIM)."""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s = mx.log(alpha_s) - mx.log(sigma_s)
h = lambda_t - lambda_s
if self.config['algorithm_type'] == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
elif self.config['algorithm_type'] == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mx.exp(h) - 1.0)) * model_output
elif self.config['algorithm_type'] == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * mx.exp(-h)) * sample +
(alpha_t * (1 - mx.exp(-2.0 * h))) * model_output +
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
)
elif self.config['algorithm_type'] == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample -
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
)
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[mx.array],
sample: mx.array,
noise: Optional[mx.array] = None,
**kwargs,
) -> mx.array:
"""One step for the second-order multistep DPMSolver."""
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config['algorithm_type'] == "dpmsolver++":
if self.config['solver_type'] == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample -
(alpha_t * (mx.exp(-h) - 1.0)) * D0 -
0.5 * (alpha_t * (mx.exp(-h) - 1.0)) * D1
)
elif self.config['solver_type'] == "heun":
x_t = (
(sigma_t / sigma_s0) * sample -
(alpha_t * (mx.exp(-h) - 1.0)) * D0 +
(alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config['algorithm_type'] == "dpmsolver":
if self.config['solver_type'] == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample -
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
0.5 * (sigma_t * (mx.exp(h) - 1.0)) * D1
)
elif self.config['solver_type'] == "heun":
x_t = (
(alpha_t / alpha_s0) * sample -
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
(sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config['algorithm_type'] == "sde-dpmsolver++":
assert noise is not None
if self.config['solver_type'] == "midpoint":
x_t = (
(sigma_t / sigma_s0 * mx.exp(-h)) * sample +
(alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
0.5 * (alpha_t * (1 - mx.exp(-2.0 * h))) * D1 +
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
)
elif self.config['solver_type'] == "heun":
x_t = (
(sigma_t / sigma_s0 * mx.exp(-h)) * sample +
(alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
(alpha_t * ((1.0 - mx.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 +
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
)
elif self.config['algorithm_type'] == "sde-dpmsolver":
assert noise is not None
if self.config['solver_type'] == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample -
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
(sigma_t * (mx.exp(h) - 1.0)) * D1 +
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
)
elif self.config['solver_type'] == "heun":
x_t = (
(alpha_t / alpha_s0) * sample -
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
)
return x_t
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[mx.array],
sample: mx.array,
**kwargs,
) -> mx.array:
"""One step for the third-order multistep DPMSolver."""
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config['algorithm_type'] == "dpmsolver++":
x_t = (
(sigma_t / sigma_s0) * sample -
(alpha_t * (mx.exp(-h) - 1.0)) * D0 +
(alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1 -
(alpha_t * ((mx.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config['algorithm_type'] == "dpmsolver":
x_t = (
(alpha_t / alpha_s0) * sample -
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
(sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 -
(sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = mx.where(schedule_timesteps == timestep)[0]
pos = 1 if len(indices) > 1 else 0
return int(indices[pos])
def _init_step_index(self, timestep):
"""Initialize the step_index counter for the scheduler."""
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: mx.array,
timestep: Union[int, mx.array],
sample: mx.array,
generator=None,
variance_noise: Optional[mx.array] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""Predict the sample from the previous timestep."""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# Improve numerical stability for small number of steps
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and
(self.config['euler_at_final'] or
(self.config['lower_order_final'] and len(self.timesteps) < 15) or
self.config['final_sigmas_type'] == "zero")
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and
self.config['lower_order_final'] and
len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config['solver_order'] - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
# Upcast to avoid precision issues
sample = sample.astype(mx.float32)
# Generate noise if needed for SDE variants
if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = mx.random.normal(model_output.shape, dtype=mx.float32)
elif self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.astype(mx.float32)
else:
noise = None
if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(
model_output, sample=sample, noise=noise
)
elif self.config['solver_order'] == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, sample=sample, noise=noise
)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, sample=sample
)
if self.lower_order_nums < self.config['solver_order']:
self.lower_order_nums += 1
# Cast sample back to expected dtype
prev_sample = prev_sample.astype(model_output.dtype)
# Increase step index
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
"""Scale model input - no scaling needed for this scheduler."""
return sample
def add_noise(
self,
original_samples: mx.array,
noise: mx.array,
timesteps: mx.array,
) -> mx.array:
"""Add noise to original samples."""
sigmas = self.sigmas.astype(original_samples.dtype)
schedule_timesteps = self.timesteps
# Get step indices
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
step_indices = [self.step_index] * timesteps.shape[0]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices]
while len(sigma.shape) < len(original_samples.shape):
sigma = mx.expand_dims(sigma, -1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config['num_train_timesteps']
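A minimal usage sketch of the class above (`denoise` stands in for an arbitrary flow-prediction model and the latent shape is illustrative; both are assumptions, not part of this file):
import mlx.core as mx
scheduler = FlowDPMSolverMultistepScheduler(
    num_train_timesteps=1000, shift=5.0, use_dynamic_shifting=False)
scheduler.set_timesteps(num_inference_steps=30)
latents = mx.random.normal((1, 16, 9, 60, 104))  # illustrative latent shape
for t in scheduler.timesteps:
    noise_pred = denoise(latents, t)  # hypothetical flow-prediction model call
    latents = scheduler.step(noise_pred, t, latents).prev_sample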

View File

@ -0,0 +1,546 @@
import math
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
class SchedulerOutput:
"""Output class for scheduler step results."""
def __init__(self, prev_sample: mx.array):
self.prev_sample = prev_sample
class FlowUniPCMultistepScheduler:
"""
MLX implementation of UniPCMultistepScheduler.
A training-free framework designed for the fast sampling of diffusion models.
"""
order = 1
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting: bool = False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero",
):
# Store configuration
self.config = {
'num_train_timesteps': num_train_timesteps,
'solver_order': solver_order,
'prediction_type': prediction_type,
'shift': shift,
'use_dynamic_shifting': use_dynamic_shifting,
'thresholding': thresholding,
'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
'sample_max_value': sample_max_value,
'predict_x0': predict_x0,
'solver_type': solver_type,
'lower_order_final': lower_order_final,
'disable_corrector': disable_corrector,
'solver_p': solver_p,
'timestep_spacing': timestep_spacing,
'steps_offset': steps_offset,
'final_sigmas_type': final_sigmas_type,
}
# Validate solver type
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.config['solver_type'] = "bh2"
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}"
)
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = mx.array(sigmas, dtype=mx.float32)
if not use_dynamic_shifting:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigma_min = float(self.sigmas[-1])
self.sigma_max = float(self.sigmas[0])
@property
def step_index(self):
"""The index counter for current timestep."""
return self._step_index
@property
def begin_index(self):
"""The index for the first timestep."""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""Sets the begin index for the scheduler."""
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, None] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""Sets the discrete timesteps used for the diffusion chain."""
if self.config['use_dynamic_shifting'] and mu is None:
raise ValueError(
"you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
if self.config['use_dynamic_shifting']:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
if shift is None:
shift = self.config['shift']
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
if self.config['final_sigmas_type'] == "sigma_min":
sigma_last = self.sigma_min
elif self.config['final_sigmas_type'] == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
)
timesteps = sigmas * self.config['num_train_timesteps']
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(timesteps, dtype=mx.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [None] * self.config['solver_order']
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers
self._step_index = None
self._begin_index = None
def _threshold_sample(self, sample: mx.array) -> mx.array:
"""Dynamic thresholding method."""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
# Flatten sample for quantile calculation
sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = mx.abs(sample_flat)
# Compute quantile
s = mx.quantile(
abs_sample,
self.config['dynamic_thresholding_ratio'],
axis=1,
keepdims=True
)
s = mx.clip(s, 1, self.config['sample_max_value'])
# Threshold and normalize
sample_flat = mx.clip(sample_flat, -s, s) / s
sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
return sample.astype(dtype)
def _sigma_to_t(self, sigma):
return sigma * self.config['num_train_timesteps']
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
def time_shift(self, mu: float, sigma: float, t: mx.array):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
def convert_model_output(
self,
model_output: mx.array,
sample: mx.array = None,
**kwargs,
) -> mx.array:
"""Convert the model output to the corresponding type the UniPC algorithm needs."""
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
f"for the UniPCMultistepScheduler."
)
if self.config['thresholding']:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
f"for the UniPCMultistepScheduler."
)
if self.config['thresholding']:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def multistep_uni_p_bh_update(
self,
model_output: mx.array,
sample: mx.array = None,
order: int = None,
**kwargs,
) -> mx.array:
"""One step for the UniP (B(h) version)."""
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
h = lambda_t - lambda_s0
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = mx.array(rks)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = mx.exp(hh) - 1 # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config['solver_type'] == "bh1":
B_h = hh
elif self.config['solver_type'] == "bh2":
B_h = mx.exp(hh) - 1
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(mx.power(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = mx.stack(R)
b = mx.array(b)
if len(D1s) > 0:
D1s = mx.stack(D1s, axis=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = mx.array([0.5], dtype=x.dtype)
else:
rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1], stream=mx.cpu).astype(x.dtype)
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.astype(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: mx.array,
last_sample: mx.array = None,
this_sample: mx.array = None,
order: int = None,
**kwargs,
) -> mx.array:
"""One step for the UniC (B(h) version)."""
model_output_list = self.model_outputs
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
h = lambda_t - lambda_s0
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1)
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = mx.array(rks)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = mx.exp(hh) - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config['solver_type'] == "bh1":
B_h = hh
elif self.config['solver_type'] == "bh2":
B_h = mx.exp(hh) - 1
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(mx.power(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = mx.stack(R)
b = mx.array(b)
if len(D1s) > 0:
D1s = mx.stack(D1s, axis=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = mx.array([0.5], dtype=x.dtype)
else:
rhos_c = mx.linalg.solve(R, b, stream=mx.cpu).astype(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.astype(x.dtype)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
condition = schedule_timesteps == timestep
indices = mx.argmax(condition.astype(mx.int32))
# Convert scalar to int and return
return int(indices)
def _init_step_index(self, timestep):
"""Initialize the step_index counter for the scheduler."""
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: mx.array,
timestep: Union[int, mx.array],
sample: mx.array,
return_dict: bool = True,
generator=None
) -> Union[SchedulerOutput, Tuple]:
"""Predict the sample from the previous timestep."""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
use_corrector = (
self.step_index > 0 and
self.step_index - 1 not in self.disable_corrector and
self.last_sample is not None
)
model_output_convert = self.convert_model_output(
model_output, sample=sample
)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
for i in range(self.config['solver_order'] - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
if self.config['lower_order_final']:
this_order = min(
self.config['solver_order'],
len(self.timesteps) - self.step_index
)
else:
this_order = self.config['solver_order']
self.this_order = min(this_order, self.lower_order_nums + 1)
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output,
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config['solver_order']:
self.lower_order_nums += 1
# Increase step index
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
"""Scale model input - no scaling needed for this scheduler."""
return sample
def add_noise(
self,
original_samples: mx.array,
noise: mx.array,
timesteps: mx.array,
) -> mx.array:
"""Add noise to original samples."""
sigmas = self.sigmas.astype(original_samples.dtype)
schedule_timesteps = self.timesteps
# Get step indices
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
step_indices = [self.step_index] * timesteps.shape[0]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices]
while len(sigma.shape) < len(original_samples.shape):
sigma = mx.expand_dims(sigma, -1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config['num_train_timesteps']
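The UniPC scheduler is driven the same way; the predictor/corrector bookkeeping (`last_sample`, `this_order`) is handled inside `step`. A sketch under the same assumptions (hypothetical `denoise` model, illustrative latent shape):
import mlx.core as mx
scheduler = FlowUniPCMultistepScheduler(
    num_train_timesteps=1000, shift=5.0, use_dynamic_shifting=False)
scheduler.set_timesteps(num_inference_steps=40)
latents = mx.random.normal((1, 16, 9, 60, 104))  # illustrative latent shape
for t in scheduler.timesteps:
    noise_pred = denoise(latents, t)  # hypothetical flow-prediction model call
    latents = scheduler.step(noise_pred, t, latents).prev_sample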

View File

@ -0,0 +1,363 @@
# Copied from https://github.com/kq-chen/qwen-vl-utils
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from __future__ import annotations
import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
logger = logging.getLogger(__name__)
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
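A quick worked example of the rounding behaviour, using the module defaults above (a 1920x1080 frame stays inside the pixel budget, so only the factor-of-28 rounding applies):
h, w = smart_resize(1080, 1920)   # defaults: factor=28, MIN_PIXELS..MAX_PIXELS budget
assert (h, w) == (1092, 1932)     # 1080 -> 39*28, 1920 -> 69*28
assert MIN_PIXELS <= h * w <= MAX_PIXELS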
def fetch_image(ele: dict[str, str | Image.Image],
size_factor: int = IMAGE_FACTOR) -> Image.Image:
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
image_obj = Image.open(requests.get(image, stream=True).raw)
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = image_obj.convert("RGB")
## resize
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=size_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", MIN_PIXELS)
max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
"""calculate the number of frames for video used for model inputs.
Args:
ele (dict): a dict contains the configuration of video.
support either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps to extract frames for model inputs.
- min_frames: the minimum number of frames of the video, only used when fps is provided.
- max_frames: the maximum number of frames of the video, only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.
Raises:
ValueError: if nframes is not in the interval [FRAME_FACTOR, total_frames].
Returns:
int: the number of video frames to use as model input.
"""
assert not ("fps" in ele and
"nframes" in ele), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(
ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(
ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
FRAME_FACTOR)
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
raise ValueError(
f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
)
return nframes
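For example, a 10-second clip at 30 fps sampled with the default fps=2.0 yields 20 frames (already a multiple of FRAME_FACTOR):
assert smart_nframes({}, total_frames=300, video_fps=30) == 20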
def _read_video_torchvision(ele: dict,) -> torch.Tensor:
"""read video using torchvision.io.read_video
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn(
"torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0."
)
if "file://" in video_path:
video_path = video_path[7:]
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(
f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
video = video[idx]
return video
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
def _read_video_decord(ele: dict,) -> torch.Tensor:
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
# TODO: support start_pts and end_pts
if 'video_start' in ele or 'video_end' in ele:
raise NotImplementedError(
"not support start_pts and end_pts in decord for now.")
total_frames, video_fps = len(vr), vr.get_avg_fps()
logger.info(
f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
return video
VIDEO_READER_BACKENDS = {
"decord": _read_video_decord,
"torchvision": _read_video_torchvision,
}
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
logger.info(
f"qwen-vl-utils using {video_reader_backend} to read video.",
file=sys.stderr)
return video_reader_backend
def fetch_video(
ele: dict,
image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(
min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
int(min_pixels * 1.05))
max_pixels = ele.get("max_pixels", max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
return video
else:
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
images = [
fetch_image({
"image": video_element,
**process_info
},
size_factor=image_factor)
for video_element in ele["video"]
]
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
if len(images) < nframes:
images.extend([images[-1]] * (nframes - len(images)))
return images
def extract_vision_info(
conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if ("image" in ele or "image_url" in ele or
"video" in ele or
ele["type"] in ("image", "image_url", "video")):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
conversations: list[dict] | list[list[dict]],
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
None]:
vision_infos = extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
elif "video" in vision_info:
video_inputs.append(fetch_video(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
return image_inputs, video_inputs
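A small usage sketch (the message layout mirrors the keys this module inspects; the file path is only a placeholder):
conversation = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "file:///path/to/clip.mp4", "fps": 2.0},
        {"type": "text", "text": "Describe the motion in this clip."},
    ],
}]
image_inputs, video_inputs = process_vision_info(conversation)
# image_inputs is None here; video_inputs holds one (T, C, H, W) tensor
# produced by fetch_video and resized via smart_resize.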

View File

@ -0,0 +1,147 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
T2V_A14B_ZH_SYS_PROMPT = \
''' 你是一位电影导演旨在为用户输入的原始prompt添加电影元素改写为优质Prompt使其完整、具有表现力。
任务要求
1. 对于用户输入的prompt,在不改变prompt的原意如主体动作前提下从下列电影美学设定中选择部分合适的时间光源光线强度光线角度对比度饱和度色调拍摄角度镜头大小构图的电影设定细节,将这些内容添加到prompt中让画面变得更美注意可以任选不必每项都有
时间["白天", "夜晚", "黎明", "日出"], 可以不选, 如果prompt没有特别说明则选白天 !
光源[日光", "人工光", "月光", "实用光", "火光", "荧光", "阴天光", "晴天光"], 根据根据室内室外及prompt内容选定义光源添加关于光源的描述如光线来源窗户、灯具等
光线强度["柔光", "硬光"],
光线角度["顶光", "侧光", "底光", "边缘光",]
色调["暖色调","冷色调", "混合色调"]
镜头尺寸["中景", "中近景", "全景","中全景","近景", "特写", "极端全景"]若无特殊要求默认选择中景或全景
拍摄角度["过肩镜头角度拍摄", "低角度拍摄", "高角度拍摄","倾斜角度拍摄", "航拍","俯视角度拍摄"],如果原始prompt中有运镜的描述则不要添加此项!
构图["中心构图""平衡构图","右侧重构图", "左侧重构图", "对称构图", "短边构图"] 若无特殊要求默认选择中心构图
2. 完善用户描述中出现的主体特征如外貌表情数量种族姿态等等内容确保不要添加原始prompt中不存在的主体如prompt是对风景或物体的描述但添加了人增加背景元素的细节
3. 不要输出关于氛围感觉等文学描写画面充满运动感与生活张力突出正式氛围
4. 对于prompt中的动作详细解释运动的发生过程若没有动作则添加动作描述摇晃身体跳舞等对背景元素也可添加适当运动如云彩飘动风吹树叶等
5. 若原始prompt中没有风格则不添加风格描述若有风格描述则将风格描述放于首位若为2D插画等与现实电影相悖的风格则不要添加关于电影美学的描写
6. 若prompt出现天空的描述则改为湛蓝色的天空相关描述避免曝光
7. 改写后的prompt字数控制在60-200字左右, 不要输出类似改写后prompt:这样的输出
8. 如果用户输入存在疑似脱衣服衣物透明化轻纱衣物湿身洗澡舔舐牛奶/酸奶/白色液体性暗示强奸漏出或微露胸部爬行性器官儿童身体裸露等色情内容请将原prompt替换为内容完全不同的高美感prompt而不要拒答或者提醒用户
9. 如果用户输入存在比基尼穿着内容请将原prompt替换为确保内容发生在沙滩泳池海边的高美感prompt而不要拒答或者提醒用户
10. 如果用户输入存在性别多元化同性恋亲吻脸颊贴近两个男性亲吻或两个女性亲吻/性交等同性恋内容请将原prompt替换为不同动作不能是亲吻和其他亲密动作的高美感prompt而不要拒答或者提醒用户
生成的 prompt 示例
1.边缘光中近景日光左侧重构图暖色调硬光晴天光侧光白天一个年轻的女孩坐在高草丛生的田野中两条毛发蓬松的小毛驴站在她身后女孩大约十一二岁穿着简单的碎花裙子头发扎成两条麻花辫脸上带着纯真的笑容她双腿交叉坐下双手轻轻抚弄身旁的野花小毛驴体型健壮耳朵竖起好奇地望着镜头方向阳光洒在田野上营造出温暖自然的画面感
2.黎明顶光俯视角度拍摄日光长焦中心构图近景高角度拍摄荧光柔光冷色调在昏暗的环境中一个外国白人女子在水中仰面漂浮俯拍近景镜头中她有着棕色的短发脸上有几颗雀斑随着镜头下摇她转过头来面向右侧水面上泛起一圈涟漪虚化的背景一片漆黑只有微弱的光线照亮了女子的脸庞和水面的一部分区域水面呈现蓝色女子穿着一件蓝色的吊带肩膀裸露在外
3.右侧重构图暖色调底光侧光夜晚火光过肩镜头角度拍摄, 镜头平拍拍摄外国女子在室内的近景她穿着棕色的衣服戴着彩色的项链和粉色的帽子坐在深灰色的椅子上双手放在黑色的桌子上眼睛看着镜头的左侧嘴巴张动左手上下晃动桌子上有白色的蜡烛有黄色的火焰后面是黑色的墙前面有黑色的网状架子旁边是黑色的箱子上面有一些黑色的物品都做了虚化的处理
4. 二次元厚涂动漫插画一个猫耳兽耳白人少女手持文件夹摇晃神情略带不满她深紫色长发红色眼睛身穿深灰色短裙和浅灰色上衣腰间系着白色系带胸前佩戴名牌上面写着黑体中文"紫阳"淡黄色调室内背景隐约可见一些家具轮廓少女头顶有一个粉色光圈线条流畅的日系赛璐璐风格近景半身略俯视视角
'''
T2V_A14B_EN_SYS_PROMPT = \
'''你是一位电影导演旨在为用户输入的原始prompt添加电影元素改写为优质英文Prompt使其完整、具有表现力注意输出必须是英文
任务要求
1. 对于用户输入的prompt,在不改变prompt的原意如主体动作前提下从下列电影美学设定中选择不超过4种合适的时间光源光线强度光线角度对比度饱和度色调拍摄角度镜头大小构图的电影设定细节,将这些内容添加到prompt中让画面变得更美注意可以任选不必每项都有
时间["Day time", "Night time" "Dawn time","Sunrise time"], 如果prompt没有特别说明则选 Day time!!!
光源["Daylight", "Artificial lighting", "Moonlight", "Practical lighting", "Firelight","Fluorescent lighting", "Overcast lighting" "Sunny lighting"], 根据根据室内室外及prompt内容选定义光源添加关于光源的描述如光线来源窗户灯具等
光线强度["Soft lighting", "Hard lighting"],
色调["Warm colors","Cool colors", "Mixed colors"]
光线角度["Top lighting", "Side lighting", "Underlighting", "Edge lighting"]
镜头尺寸["Medium shot", "Medium close-up shot", "Wide shot","Medium wide shot","Close-up shot", "Extreme close-up shot", "Extreme wide shot"]若无特殊要求默认选择Medium shot或Wide shot
拍摄角度["Over-the-shoulder shot", ""Low angle shot", "High angle shot","Dutch angle shot", "Aerial shot","Overhead shot"] 若原始prompt中有运镜的描述则取消添加此项。
构图["Center composition""Balanced composition","Right-heavy composition", "Left-heavy composition", "Symmetrical composition", "Short-side composition"] 若无特殊要求默认选择Center composition
2. 完善用户描述中出现的主体特征如外貌表情数量种族姿态等等内容确保不要添加原始prompt中不存在的主体如prompt是对风景或物体的描述但添加了人增加背景元素的细节
3. 不要输出关于氛围感觉等文学描写画面充满运动感与生活张力突出正式氛围
4. 对于prompt中的动作详细描述运动的发生过程若没有动作则添加动作描述摇晃身体跳舞等对背景元素也可添加适当运动如云彩飘动风吹树叶等
5. 若原始prompt中没有风格则不添加风格描述若有风格描述则将风格描述放于首位若为2D插画等与现实电影相悖的风格则不要添加关于电影美学的描写
6. 若prompt出现天空的描述则改为湛蓝色的天空相关描述避免曝光
7. 输出必须是全英文改写后的prompt字数控制在60-200字左右, 不要输出类似改写后prompt:这样的输出
8. 如果用户输入存在疑似脱衣服衣物透明化轻纱衣物湿身洗澡舔舐牛奶/酸奶/白色液体性暗示强奸漏出或微露胸部爬行性器官儿童身体裸露等色情内容请将原prompt替换为内容完全不同的高美感prompt而不要拒答或者提醒用户
9. 如果用户输入存在比基尼穿着内容请将原prompt替换为确保内容发生在沙滩泳池海边的高美感prompt而不要拒答或者提醒用户
10. 如果用户输入存在性别多元化同性恋亲吻脸颊贴近两个男性亲吻或两个女性亲吻/性交等同性恋内容请将原prompt替换为不同动作不能是亲吻和其他亲密动作的高美感prompt而不要拒答或者提醒用户
生成的 prompt 示例
1.Edge lighting, medium close-up shot, daylight, left-heavy composition. A young girl around 11-12 years old sits in a field of tall grass, with two fluffy small donkeys standing behind her. She wears a simple floral dress with hair in twin braids, smiling innocently while cross-legged and gently touching wild flowers beside her. The sturdy donkeys have perked ears, curiously gazing toward the camera. Sunlight bathes the field, creating a warm natural atmosphere.
2.Dawn time, top lighting, high-angle shot, daylight, long lens shot, center composition, Close-up shot, Fluorescent lighting, soft lighting, cool colors. In dim surroundings, a Caucasian woman floats on her back in water. The high-angle close-up shows her brown short hair and freckled face. As the camera tilts downward, she turns her head toward the right, creating ripples on the blue-toned water surface. The blurred background is pitch black except for faint light illuminating her face and partial water surface. She wears a blue sleeveless top with bare shoulders.
3.Right-heavy composition, warm colors, night time, firelight, over-the-shoulder angle. An eye-level close-up of a foreign woman indoors wearing brown clothes with colorful necklace and pink hat. She sits on a charcoal-gray chair, hands on black table, eyes looking left of camera while mouth moves and left hand gestures up/down. White candles with yellow flames sit on the table. Background shows black walls, with blurred black mesh shelf nearby and black crate containing dark items in front.
4."Anime-style thick-painted style. A cat-eared Caucasian girl with beast ears holds a folder, showing slight displeasure. Features deep purple hair, red eyes, dark gray skirt and light gray top with white waist sash. A name tag labeled 'Ziyang' in bold Chinese characters hangs on her chest. Pale yellow indoor background with faint furniture outlines. A pink halo floats above her head. Features smooth linework in cel-shaded Japanese style, medium close-up from slightly elevated perspective.
'''
I2V_A14B_ZH_SYS_PROMPT = \
'''你是一个视频描述提示词的改写专家,你的任务是根据用户给你输入的图像,对提供的视频描述提示词进行改写,你要强调潜在的动态内容。具体要求如下
用户输入的语言可能含有多样化的描述如markdown文档格式指令格式长度过长或者过短你需要根据图片的内容和用户的输入的提示词尽可能提取用户输入的提示词和图片关联信息
你改写的视频描述结果要尽可能保留提供给你的视频描述提示词中动态部分保留主体的动作
你要根据图像强调并简化视频描述提示词中的图像主体如果用户只提供了动作你要根据图像内容合理补充跳舞补充称一个女孩在跳舞
如果用户输入的提示词过长你需要提炼潜在的动作过程
如果用户输入的提示词过短综合用户输入的提示词以及画面内容合理的增加潜在的运动信息
你要根据图像保留并强调视频描述提示词中关于运镜手段的描述镜头上摇镜头从左到右镜头从右到左等等你要保留镜头拍摄两个男人打斗他们先是躺在地上随后镜头向上移动拍摄他们站起来接着镜头向左移动左边男人拿着一个蓝色的东西右边男人上前抢夺两人激烈地来回争抢
你需要给出对视频描述的动态内容不要添加对于静态场景的描述如果用户输入的描述已经在画面中出现则移除这些描述
改写后的prompt字数控制在100字以下
无论用户输入那种语言你都需要输出中文
改写后 prompt 示例
1. 镜头后拉拍摄两个外国男人走在楼梯上镜头左侧的男人右手搀扶着镜头右侧的男人
2. 一只黑色的小松鼠专注地吃着东西偶尔抬头看看四周
3. 男子说着话表情从微笑逐渐转变为闭眼然后睁开眼睛最后是闭眼微笑他的手势活跃在说话时做出一系列的手势
4. 一个人正在用尺子和笔进行测量的特写右手用一支黑色水性笔在纸上画出一条直线
5. 一辆车模型在木板上行驶车辆从画面的右侧向左侧移动经过一片草地和一些木制结构
6. 镜头左移后前推拍摄一个人坐在防波堤上
7. 男子说着话他的表情和手势随着对话内容的变化而变化但整体场景保持不变
8. 镜头左移后前推拍摄一个人坐在防波堤上
9. 带着珍珠项链的女子看向画面右侧并说着话
请直接输出改写后的文本不要进行多余的回复'''
I2V_A14B_EN_SYS_PROMPT = \
'''You are an expert in rewriting video description prompts. Your task is to rewrite the provided video description prompts based on the images given by users, emphasizing potential dynamic content. Specific requirements are as follows:
The user's input language may include diverse descriptions, such as markdown format, instruction format, or be too long or too short. You need to extract the relevant information from the user's input and associate it with the image content.
Your rewritten video description should retain the dynamic parts of the provided prompts, focusing on the main subject's actions. Emphasize and simplify the main subject of the image while retaining their movement. If the user only provides an action (e.g., "dancing"), supplement it reasonably based on the image content (e.g., "a girl is dancing").
If the user's input prompt is too long, refine it to capture the essential action process. If the input is too short, add reasonable motion-related details based on the image content.
Retain and emphasize descriptions of camera movements, such as "the camera pans up," "the camera moves from left to right," or "the camera moves from right to left." For example: "The camera captures two men fighting. They start lying on the ground, then the camera moves upward as they stand up. The camera shifts left, showing the man on the left holding a blue object while the man on the right tries to grab it, resulting in a fierce back-and-forth struggle."
Focus on dynamic content in the video description and avoid adding static scene descriptions. If the user's input already describes elements visible in the image, remove those static descriptions.
Limit the rewritten prompt to 100 words or less. Regardless of the input language, your output must be in English.
Examples of rewritten prompts:
The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.
A black squirrel focuses on eating, occasionally looking around.
A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.
A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.
A model car moves on a wooden board, traveling from right to left across grass and wooden structures.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A woman wearing a pearl necklace looks to the right and speaks.
Output only the rewritten text without additional responses.'''
I2V_A14B_EMPTY_ZH_SYS_PROMPT = \
'''你是一个视频描述提示词的撰写专家,你的任务是根据用户给你输入的图像,发挥合理的想象,让这张图动起来,你要强调潜在的动态内容。具体要求如下
你需要根据图片的内容想象出运动的主体
你输出的结果应强调图片中的动态部分保留主体的动作
你需要给出对视频描述的动态内容不要有过多的对于静态场景的描述
输出的prompt字数控制在100字以下
你需要输出中文
prompt 示例
1. 镜头后拉拍摄两个外国男人走在楼梯上镜头左侧的男人右手搀扶着镜头右侧的男人
2. 一只黑色的小松鼠专注地吃着东西偶尔抬头看看四周
3. 男子说着话表情从微笑逐渐转变为闭眼然后睁开眼睛最后是闭眼微笑他的手势活跃在说话时做出一系列的手势
4. 一个人正在用尺子和笔进行测量的特写右手用一支黑色水性笔在纸上画出一条直线
5. 一辆车模型在木板上行驶车辆从画面的右侧向左侧移动经过一片草地和一些木制结构
6. 镜头左移后前推拍摄一个人坐在防波堤上
7. 男子说着话他的表情和手势随着对话内容的变化而变化但整体场景保持不变
8. 镜头左移后前推拍摄一个人坐在防波堤上
9. 带着珍珠项链的女子看向画面右侧并说着话
请直接输出文本不要进行多余的回复'''
I2V_A14B_EMPTY_EN_SYS_PROMPT = \
'''You are an expert in writing video description prompts. Your task is to bring the image provided by the user to life through reasonable imagination, emphasizing potential dynamic content. Specific requirements are as follows:
You need to imagine the moving subject based on the content of the image.
Your output should emphasize the dynamic parts of the image and retain the main subject's actions.
Focus only on describing dynamic content; avoid excessive descriptions of static scenes.
Limit the output prompt to 100 words or less.
The output must be in English.
Prompt examples:
The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.
A black squirrel focuses on eating, occasionally looking around.
A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.
A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.
A model car moves on a wooden board, traveling from right to left across grass and wooden structures.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A woman wearing a pearl necklace looks to the right and speaks.
Output only the text without additional responses.'''
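These constants are plain strings, so wiring one into a chat-style rewrite request is straightforward; the message format and the helper below are illustrative only, not part of this file:
def build_rewrite_messages(user_prompt: str, use_en: bool = True) -> list:
    # Pair the chosen system prompt with the raw user prompt.
    system = T2V_A14B_EN_SYS_PROMPT if use_en else T2V_A14B_ZH_SYS_PROMPT
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user_prompt},
    ]
messages = build_rewrite_messages("a girl dancing on the beach at sunset")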

View File

@ -0,0 +1,233 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# utils MLX version
import argparse
import binascii
import logging
import os
import os.path as osp
import imageio
import mlx.core as mx
import numpy as np
__all__ = ['save_video', 'save_image', 'str2bool', 'masks_like', 'best_output_size']
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:
if not suffix.startswith('.'):
suffix = '.' + suffix
name += suffix
return name
def make_grid(tensor, nrow=8, normalize=True, value_range=(-1, 1)):
"""MLX equivalent of torchvision.utils.make_grid"""
# tensor shape: (batch, channels, height, width)
batch_size, channels, height, width = tensor.shape
# Calculate grid dimensions
ncol = nrow
nrow_actual = (batch_size + ncol - 1) // ncol
# Create grid
grid_height = height * nrow_actual + (nrow_actual - 1) * 2 # 2 pixel padding
grid_width = width * ncol + (ncol - 1) * 2
# Initialize grid with zeros
grid = mx.zeros((channels, grid_height, grid_width))
# Fill grid
for idx in range(batch_size):
row = idx // ncol
col = idx % ncol
y_start = row * (height + 2)
y_end = y_start + height
x_start = col * (width + 2)
x_end = x_start + width
img = tensor[idx]
if normalize:
# Normalize to [0, 1]
img = (img - value_range[0]) / (value_range[1] - value_range[0])
grid[:, y_start:y_end, x_start:x_end] = img
return grid
def save_video(tensor,
save_file=None,
fps=30,
suffix='.mp4',
nrow=8,
normalize=True,
value_range=(-1, 1)):
# cache file
cache_file = osp.join('/tmp', rand_name(
suffix=suffix)) if save_file is None else save_file
# save to cache
try:
# preprocess
tensor = mx.clip(tensor, value_range[0], value_range[1])
# tensor shape: (batch, channels, frames, height, width)
# Process each frame
frames = []
for frame_idx in range(tensor.shape[2]):
frame = tensor[:, :, frame_idx, :, :] # (batch, channels, height, width)
grid = make_grid(frame, nrow=nrow, normalize=normalize, value_range=value_range)
frames.append(grid)
# Stack frames and convert to (frames, height, width, channels)
tensor = mx.stack(frames, axis=0) # (frames, channels, height, width)
tensor = mx.transpose(tensor, [0, 2, 3, 1]) # (frames, height, width, channels)
# Convert to uint8
tensor = (tensor * 255).astype(mx.uint8)
tensor_np = np.array(tensor)
# write video
writer = imageio.get_writer(
cache_file, fps=fps, codec='libx264', quality=8)
for frame in tensor_np:
writer.append_data(frame)
writer.close()
except Exception as e:
logging.info(f'save_video failed, error: {e}')
def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
# cache file
suffix = osp.splitext(save_file)[1]
if suffix.lower() not in [
'.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
]:
suffix = '.png'
# save to cache
try:
# Clip values
tensor = mx.clip(tensor, value_range[0], value_range[1])
# Make grid
grid = make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range)
# Convert to (height, width, channels) and uint8
grid = mx.transpose(grid, [1, 2, 0]) # (height, width, channels)
grid = (grid * 255).astype(mx.uint8)
# Save using imageio
imageio.imwrite(save_file, np.array(grid))
return save_file
except Exception as e:
logging.info(f'save_image failed, error: {e}')
def str2bool(v):
"""
Convert a string to a boolean.
Supported true values: 'yes', 'true', 't', 'y', '1'
Supported false values: 'no', 'false', 'f', 'n', '0'
Args:
v (str): String to convert.
Returns:
bool: Converted boolean value.
Raises:
argparse.ArgumentTypeError: If the value cannot be converted to boolean.
"""
if isinstance(v, bool):
return v
v_lower = v.lower()
if v_lower in ('yes', 'true', 't', 'y', '1'):
return True
elif v_lower in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
def masks_like(tensor, zero=False, generator=None, p=0.2):
"""
Generate masks similar to input tensors.
Args:
tensor: List of MLX arrays
zero: Whether to apply zero masking
generator: Random generator (for MLX, we use mx.random.seed instead)
p: Probability for random masking
Returns:
Tuple of two lists of masks
"""
assert isinstance(tensor, list)
out1 = [mx.ones(u.shape, dtype=u.dtype) for u in tensor]
out2 = [mx.ones(u.shape, dtype=u.dtype) for u in tensor]
if zero:
if generator is not None:
# MLX doesn't have the same generator API as PyTorch
# We'll use random state instead
for u, v in zip(out1, out2):
random_num = mx.random.uniform(0, 1, shape=(1,)).item()
if random_num < p:
# Generate random values with normal distribution
normal_vals = mx.random.normal(shape=u[:, 0].shape, loc=-3.5, scale=0.5)
u[:, 0] = mx.exp(normal_vals)
v[:, 0] = mx.zeros_like(v[:, 0])
else:
# Keep original values
u[:, 0] = u[:, 0]
v[:, 0] = v[:, 0]
else:
for u, v in zip(out1, out2):
u[:, 0] = mx.zeros_like(u[:, 0])
v[:, 0] = mx.zeros_like(v[:, 0])
return out1, out2
def best_output_size(w, h, dw, dh, expected_area):
"""
Calculate the best output size given constraints.
Args:
w: Width
h: Height
dw: Width divisor
dh: Height divisor
expected_area: Target area
Returns:
Tuple of (output_width, output_height)
"""
# float output size
ratio = w / h
ow = (expected_area * ratio)**0.5
oh = expected_area / ow
# process width first
ow1 = int(ow // dw * dw)
oh1 = int(expected_area / ow1 // dh * dh)
assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area
ratio1 = ow1 / oh1
# process height first
oh2 = int(oh // dh * dh)
ow2 = int(expected_area / oh2 // dw * dw)
assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area
ratio2 = ow2 / oh2
# compare ratios
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2,
ratio2 / ratio):
return ow1, oh1
else:
return ow2, oh2
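Two quick examples of the helpers above (the argparse flag name is illustrative; argparse is already imported at the top of this module):
parser = argparse.ArgumentParser()
parser.add_argument("--offload_model", type=str2bool, default=False)  # accepts yes/no, true/false, 1/0
args = parser.parse_args(["--offload_model", "yes"])  # args.offload_model == True
# Closest 16-divisible size to a 1280x720 pixel budget, preserving aspect ratio.
ow, oh = best_output_size(w=1280, h=720, dw=16, dh=16, expected_area=1280 * 720)
assert (ow, oh) == (1280, 720)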

View File

@ -0,0 +1,160 @@
import re
import torch
import mlx.core as mx
import numpy as np
from typing import Dict, Tuple
from safetensors import safe_open
def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = False):
"""
Convert PyTorch VAE weights to MLX format with correct mapping.
"""
print(f"Converting {pytorch_path} -> {output_path}")
dtype = mx.float16 if float16 else mx.float32
# Load PyTorch weights
if pytorch_path.endswith('.safetensors'):
weights = {}
with safe_open(pytorch_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
if tensor.dtype == torch.bfloat16:
tensor = tensor.float()
weights[key] = tensor.numpy()
else:
checkpoint = torch.load(pytorch_path, map_location='cpu')
weights = {}
state_dict = checkpoint if isinstance(checkpoint, dict) and 'state_dict' not in checkpoint else checkpoint.get('state_dict', checkpoint)
for key, tensor in state_dict.items():
if tensor.dtype == torch.bfloat16:
tensor = tensor.float()
weights[key] = tensor.numpy()
# Convert weights
mlx_weights = {}
for key, value in weights.items():
# Skip these
if any(skip in key for skip in ["num_batches_tracked", "running_mean", "running_var"]):
continue
# Convert weight formats
if value.ndim == 5 and "weight" in key: # Conv3d weights
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
value = np.transpose(value, (0, 2, 3, 4, 1))
elif value.ndim == 4 and "weight" in key: # Conv2d weights
# PyTorch: (out_channels, in_channels, H, W)
# MLX Conv2d expects: (out_channels, H, W, in_channels)
value = np.transpose(value, (0, 2, 3, 1))
elif value.ndim == 1 and "bias" in key: # Conv biases
# Keep as is - MLX uses same format
pass
# Map the key
new_key = key
# Map residual block internals within Sequential
# PyTorch: encoder.downsamples.0.residual.0.gamma
# MLX: encoder.downsamples.layers.0.residual.layers.0.gamma
# Add .layers to Sequential modules
new_key = re.sub(r'\.downsamples\.(\d+)', r'.downsamples.layers.\1', new_key)
new_key = re.sub(r'\.upsamples\.(\d+)', r'.upsamples.layers.\1', new_key)
new_key = re.sub(r'\.middle\.(\d+)', r'.middle.layers.\1', new_key)
new_key = re.sub(r'\.head\.(\d+)', r'.head.layers.\1', new_key)
# Map residual Sequential internals
if ".residual." in new_key:
match = re.search(r'\.residual\.(\d+)\.', new_key)
if match:
idx = int(match.group(1))
if idx == 0: # First RMS_norm
new_key = re.sub(r'\.residual\.0\.', '.residual.layers.0.', new_key)
elif idx == 1: # SiLU - skip
continue
elif idx == 2: # First Conv3d
new_key = re.sub(r'\.residual\.2\.', '.residual.layers.2.', new_key)
elif idx == 3: # Second RMS_norm
new_key = re.sub(r'\.residual\.3\.', '.residual.layers.3.', new_key)
elif idx == 4: # Second SiLU - skip
continue
elif idx == 5: # Dropout - could be Identity in MLX
if "Dropout" in key:
continue
new_key = re.sub(r'\.residual\.5\.', '.residual.layers.5.', new_key)
elif idx == 6: # Second Conv3d
new_key = re.sub(r'\.residual\.6\.', '.residual.layers.6.', new_key)
# Map resample internals
if ".resample." in new_key:
# In both Encoder and Decoder Resample blocks, the Conv2d is at index 1
# in the nn.Sequential block, following either a padding or upsample layer.
# We just need to map PyTorch's .1 to MLX's .layers.1
if ".resample.1." in new_key:
new_key = new_key.replace(".resample.1.", ".resample.layers.1.")
# The layers at index 0 (ZeroPad2d, Upsample) have no weights, so we can
# safely skip any keys associated with them.
if ".resample.0." in key:
continue
        # Head internals are already handled by the .head -> .head.layers regex above
        # Shortcut layers: Conv3d shortcuts keep their names as-is, and Identity
        # shortcuts contribute no parameters, so nothing needs to be remapped or skipped
        if ".shortcut." in new_key:
            pass
# Handle time_conv in Resample
if "time_conv" in new_key:
# Keep as is - already correctly named
pass
# Handle attention layers
if "to_qkv" in new_key or "proj" in new_key:
# Keep as is - already correctly named
pass
        # Normalize RMS_norm gamma parameters
        if "gamma" in new_key:
            # Squeeze gamma from (C, 1, 1) or (C, 1, 1, 1) down to a 1D (C,) array
            value = np.squeeze(value)  # removes all singleton dimensions
# Add to MLX weights
mlx_weights[new_key] = mx.array(value).astype(dtype)
# Verify critical layers are present
critical_prefixes = [
"encoder.conv1", "decoder.conv1", "conv1", "conv2",
"encoder.head.layers.2", "decoder.head.layers.2" # Updated for Sequential
]
for prefix in critical_prefixes:
found = any(k.startswith(prefix) for k in mlx_weights.keys())
if not found:
print(f"WARNING: No weights found for {prefix}")
print(f"Converted {len(mlx_weights)} parameters")
# Print a few example keys for verification
print("\nExample converted keys:")
    for key in sorted(mlx_weights.keys())[:10]:
print(f" {key}")
# Save
if output_path.endswith('.safetensors'):
mx.save_safetensors(output_path, mlx_weights)
else:
mx.savez(output_path, **mlx_weights)
print(f"\nSaved to {output_path}")
print("\nAll converted keys:")
for key in sorted(mlx_weights.keys()):
print(f" {key}: {mlx_weights[key].shape}")
return mlx_weights
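
# Usage sketch (added for illustration; the file names below are placeholders,
# not paths shipped with the repo):
#
#     mlx_weights = convert_pytorch_to_mlx(
#         "Wan2.2_VAE.pth",              # hypothetical PyTorch VAE checkpoint
#         "wan2.2_vae_mlx.safetensors",  # hypothetical MLX output path
#         float16=True,
#     )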

View File

@ -0,0 +1,294 @@
# wan_model_io.py
from typing import List, Tuple, Set, Dict
import os
import mlx.core as mx
from mlx.utils import tree_unflatten, tree_flatten
from safetensors import safe_open
import torch
import numpy as np
import glob
def map_wan_2_2_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
"""Map PyTorch WAN 2.2 weights to MLX format."""
# Only add .layers to Sequential WITHIN components, not to blocks themselves
# blocks.N stays as blocks.N (not blocks.layers.N)
# Handle Sequential layers - PyTorch uses .0, .1, .2, MLX uses .layers.0, .layers.1, .layers.2
# Only for components INSIDE blocks and top-level modules
if ".ffn." in key and not ".layers." in key:
# Replace .ffn.0 with .ffn.layers.0, etc.
key = key.replace(".ffn.0.", ".ffn.layers.0.")
key = key.replace(".ffn.1.", ".ffn.layers.1.")
key = key.replace(".ffn.2.", ".ffn.layers.2.")
if "text_embedding." in key and not ".layers." in key:
for i in range(10):
key = key.replace(f"text_embedding.{i}.", f"text_embedding.layers.{i}.")
if "time_embedding." in key and not ".layers." in key:
for i in range(10):
key = key.replace(f"time_embedding.{i}.", f"time_embedding.layers.{i}.")
if "time_projection." in key and not ".layers." in key:
for i in range(10):
key = key.replace(f"time_projection.{i}.", f"time_projection.layers.{i}.")
# Handle conv transpose for patch_embedding
if "patch_embedding.weight" in key:
# PyTorch Conv3d: (out_channels, in_channels, D, H, W)
# MLX Conv3d: (out_channels, D, H, W, in_channels)
value = mx.transpose(value, (0, 2, 3, 4, 1))
return [(key, value)]
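
# Illustrative sanity check of the mapping above (a sketch with made-up keys and
# shapes, not values read from a real checkpoint):
#
#     k, _ = map_wan_2_2_weights("blocks.0.ffn.0.weight", mx.zeros((8, 4)))[0]
#     assert k == "blocks.0.ffn.layers.0.weight"
#
#     k, v = map_wan_2_2_weights("patch_embedding.weight", mx.zeros((16, 3, 1, 2, 2)))[0]
#     assert k == "patch_embedding.weight" and v.shape == (16, 1, 2, 2, 3)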
def check_parameter_mismatch(model, weights: Dict[str, mx.array]) -> Tuple[Set[str], Set[str]]:
"""
Check for parameter mismatches between model and weights.
Returns:
(model_only, weights_only): Sets of parameter names that exist only in model or weights
"""
# Get all parameter names from model
model_params = dict(tree_flatten(model.parameters()))
model_keys = set(model_params.keys())
# Remove computed buffers that aren't loaded from weights
computed_buffers = {'freqs'} # Add any other computed buffers here
model_keys = model_keys - computed_buffers
# Get all parameter names from weights
weight_keys = set(weights.keys())
# Find differences
model_only = model_keys - weight_keys
weights_only = weight_keys - model_keys
return model_only, weights_only
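
# Typical use (sketch; "model" and the weights path are assumptions for illustration):
#
#     weights = mx.load("diffusion_mlx_model-00001-of-00006.safetensors")
#     model_only, weights_only = check_parameter_mismatch(model, weights)
#     if model_only or weights_only:
#         print("parameter names do not line up")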
def load_wan_2_2_from_safetensors(
safetensors_path: str,
model,
float16: bool = False,
check_mismatch: bool = True
):
"""
Load WAN 2.2 Model weights from safetensors file(s) into MLX model.
Args:
safetensors_path: Path to safetensors file or directory
model: MLX model instance
float16: Whether to use float16 precision
check_mismatch: Whether to check for parameter mismatches
"""
    # Load weights from a directory of shards or from a single file, then run one
    # mismatch check and a single model update for both cases.
    if os.path.isdir(safetensors_path):
        # Multiple files (14B model) - only diffusion_mlx_model files
        pattern = os.path.join(safetensors_path, "diffusion_mlx_model*.safetensors")
        safetensor_files = sorted(glob.glob(pattern))
        print(f"Found {len(safetensor_files)} diffusion_mlx_model safetensors files")
        # Load all files and merge weights
        weights = {}
        for file_path in safetensor_files:
            print(f"Loading: {file_path}")
            weights.update(mx.load(file_path))
    else:
        # Single file
        print(f"Loading single file: {safetensors_path}")
        weights = mx.load(safetensors_path)
    if check_mismatch:
        model_only, weights_only = check_parameter_mismatch(model, weights)
        if model_only:
            print(f"\n⚠️ WARNING: {len(model_only)} parameters in model but NOT in weights:")
            for param in sorted(model_only)[:10]:  # Show first 10
                print(f"  - {param}")
            if len(model_only) > 10:
                print(f"  ... and {len(model_only) - 10} more")
        if weights_only:
            print(f"\n⚠️ WARNING: {len(weights_only)} parameters in weights but NOT in model:")
            for param in sorted(weights_only)[:10]:  # Show first 10
                print(f"  - {param}")
            if len(weights_only) > 10:
                print(f"  ... and {len(weights_only) - 10} more")
        if not model_only and not weights_only:
            print("\n✅ Perfect match: All parameters align between model and weights!")
    model.update(tree_unflatten(list(weights.items())))
print("\nWAN 2.2 Model weights loaded successfully!")
return model
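
# Usage sketch (illustrative; the checkpoint directory and model constructor are
# assumptions, not names defined in this file):
#
#     model = WanModel(...)  # hypothetical MLX model instance
#     load_wan_2_2_from_safetensors(
#         "Wan2.2-T2V-A14B/low_noise_model",  # directory of diffusion_mlx_model*.safetensors
#         model,
#         check_mismatch=True,
#     )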
def convert_wan_2_2_safetensors_to_mlx(
safetensors_path: str,
output_path: str,
float16: bool = False,
model=None # Optional: provide model instance to check parameter alignment
):
"""
Convert WAN 2.2 PyTorch safetensors file to MLX weights file.
Args:
safetensors_path: Input safetensors file
output_path: Output MLX weights file (.safetensors)
float16: Whether to use float16 precision
model: Optional MLX model instance to check parameter alignment
"""
dtype = mx.float16 if float16 else mx.float32
print(f"Converting WAN 2.2 safetensors to MLX format...")
print(f"Input: {safetensors_path}")
print(f"Output: {output_path}")
print(f"Target dtype: {dtype}")
# Load and convert weights
weights = {}
bfloat16_count = 0
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
keys = list(f.keys())
print(f"Processing {len(keys)} parameters...")
for key in keys:
tensor = f.get_tensor(key)
# Handle BFloat16
if tensor.dtype == torch.bfloat16:
bfloat16_count += 1
tensor = tensor.float() # Convert to float32 first
value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_wan_2_2_weights(key, value)
for new_key, new_value in mapped:
weights[new_key] = new_value
if bfloat16_count > 0:
print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to {dtype}")
# Check parameter alignment if model provided
if model is not None:
print("\nChecking parameter alignment with model...")
model_only, weights_only = check_parameter_mismatch(model, weights)
if model_only:
print(f"\n⚠️ WARNING: {len(model_only)} parameters in model but NOT in converted weights:")
for param in sorted(model_only)[:10]:
print(f" - {param}")
if len(model_only) > 10:
print(f" ... and {len(model_only) - 10} more")
if weights_only:
print(f"\n⚠️ WARNING: {len(weights_only)} parameters in converted weights but NOT in model:")
for param in sorted(weights_only)[:10]:
print(f" - {param}")
if len(weights_only) > 10:
print(f" ... and {len(weights_only) - 10} more")
if not model_only and not weights_only:
print("\n✅ Perfect match: All parameters align between model and converted weights!")
# Save as MLX format
print(f"\nSaving {len(weights)} parameters to: {output_path}")
mx.save_safetensors(output_path, weights)
# Print a few example keys for verification
print("\nExample converted keys:")
    for key in sorted(weights.keys())[:10]:
print(f" {key}: {weights[key].shape}")
return weights
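
# Usage sketch for a single shard (file names are placeholders that follow the
# diffusion_pytorch_model-* / diffusion_mlx_model-* naming used above):
#
#     convert_wan_2_2_safetensors_to_mlx(
#         "diffusion_pytorch_model-00001-of-00006.safetensors",
#         "diffusion_mlx_model-00001-of-00006.safetensors",
#         float16=True,
#     )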
def convert_multiple_wan_2_2_safetensors_to_mlx(
checkpoint_dir: str,
float16: bool = False
):
"""Convert multiple WAN 2.2 PyTorch safetensors files to MLX format."""
# Find all PyTorch model files
pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
pytorch_files = sorted(glob.glob(pytorch_pattern))
if not pytorch_files:
raise FileNotFoundError(f"No PyTorch model files found matching: {pytorch_pattern}")
print(f"Converting {len(pytorch_files)} PyTorch files to MLX format...")
for i, pytorch_file in enumerate(pytorch_files, 1):
# Extract the suffix (e.g., "00001-of-00006")
basename = os.path.basename(pytorch_file)
suffix = basename.replace("diffusion_pytorch_model-", "").replace(".safetensors", "")
# Create MLX filename
mlx_file = os.path.join(checkpoint_dir, f"diffusion_mlx_model-{suffix}.safetensors")
print(f"\nConverting {i}/{len(pytorch_files)}: {basename}")
convert_wan_2_2_safetensors_to_mlx(pytorch_file, mlx_file, float16)
print("\nAll files converted successfully!")
def debug_wan_2_2_weight_mapping(safetensors_path: str, float16: bool = False):
"""
Debug function to see how WAN 2.2 weights are being mapped.
"""
dtype = mx.float16 if float16 else mx.float32
print("=== WAN 2.2 Weight Mapping Debug ===")
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
# Check first 30 keys to see the mapping
for i, key in enumerate(f.keys()):
if i >= 30:
break
tensor = f.get_tensor(key)
            # Handle BFloat16
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()
            value = mx.array(tensor.numpy()).astype(dtype)
            # Apply mapping
            mapped = map_wan_2_2_weights(key, value)
            new_key, _ = mapped[0]
if new_key == key:
print(f"UNCHANGED: {key} [{tensor.shape}]")
else:
print(f"MAPPED: {key} -> {new_key} [{tensor.shape}]")