Implement Wan 2.1

N 2025-07-28 15:51:11 -07:00
parent 4b2a0df237
commit a36538bf9f
40 changed files with 7431 additions and 0 deletions

38
video/Wan2.1/.gitignore vendored Normal file
View File

@@ -0,0 +1,38 @@
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
cache
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
venv_wan/
venv_wan_py310/

201
video/Wan2.1/LICENSE.txt Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

83
video/Wan2.1/README.md Normal file
View File

@@ -0,0 +1,83 @@
# Wan2.1
## Quickstart
#### Installation
Install dependencies:
```
pip install -r requirements.txt
```
#### Model Download
| Models | Download Link | Notes |
| --------------|-------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
> 💡Note: The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution. Also, note that the MLX port currently only supports T2V.
Download models using huggingface-cli:
```
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
```
Download models using modelscope-cli:
```
pip install modelscope
modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
```
#### Run Text-to-Video Generation
This repository currently supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P</th>
<th>720P</th>
</tr>
</thead>
<tbody>
<tr>
<td>t2v-14B</td>
<td style="color: green;">✔️</td>
<td style="color: green;">✔️</td>
<td>Wan2.1-T2V-14B</td>
</tr>
<tr>
<td>t2v-1.3B</td>
<td style="color: green;">✔️</td>
<td style="color: red;"></td>
<td>Wan2.1-T2V-1.3B</td>
</tr>
</tbody>
</table>
##### (1) Example:
```
python generate_mlx.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4
```
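The same pipeline can also be driven from Python. The sketch below mirrors what `generate_mlx.py` does with `wan.WanT2V`; argument names follow that script, so treat it as a starting point rather than a fixed API:
```
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from wan.utils.utils_mlx_own import cache_video

cfg = WAN_CONFIGS["t2v-1.3B"]
pipe = wan.WanT2V(config=cfg, checkpoint_dir="./Wan2.1-T2V-1.3B")

video = pipe.generate(
    "Lion running under snow in Samarkand",
    size=SIZE_CONFIGS["480*832"],
    frame_num=16,
    shift=5.0,
    sample_solver="unipc",
    sampling_steps=25,
    guide_scale=5.0,
    seed=42,
    offload_model=True,
)
cache_video(
    tensor=video[None],
    save_file="output_video_mlx.mp4",
    fps=cfg.sample_fps,
    nrow=1,
    normalize=True,
    value_range=(-1, 1),
)
```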
## Citation
Credits to the Wan Team for the original PyTorch implementation.
```
@article{wan2.1,
title = {Wan: Open and Advanced Large-Scale Video Generative Models},
author = {Wan Team},
journal = {},
year = {2025}
}
```

9 binary files not shown (new assets; sizes: 1.7 MiB, 516 KiB, 871 KiB, 55 KiB, 294 KiB, 1.5 MiB, 628 KiB, 208 KiB, 245 KiB).

View File

@@ -0,0 +1,245 @@
import argparse
from datetime import datetime
import logging
import os
import sys
import random
import mlx.core as mx
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils_mlx_own import cache_video, cache_image, str2bool
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
},
"i2v-14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
    # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81
# T2I frame_num check
if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
    assert args.size in SUPPORTED_SIZES[args.task], \
        f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames to sample from a image or video. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the image or video from.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=5.0,
help="Classifier free guidance scale.")
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging():
# logging
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
def generate(args):
_init_logging()
# MLX uses default device automatically
if args.offload_model is None:
args.offload_model = True # Default to True to save memory
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
cfg = WAN_CONFIGS[args.task]
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
)
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
args.image = EXAMPLE_PROMPT[args.task]["image"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}")
img = Image.open(args.image).convert("RGB")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
# Save output
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
# Note: cache_image might need to handle MLX arrays
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)

View File

@@ -0,0 +1,18 @@
torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
# flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2
mlx
scikit-image

View File

@@ -0,0 +1,6 @@
Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify the maximum number of GPUs you want to use.
```bash
bash ./test.sh <local model dir> <gpu number>
```

113
video/Wan2.1/tests/test.sh Normal file
View File

@@ -0,0 +1,113 @@
#!/bin/bash
if [ "$#" -eq 2 ]; then
MODEL_DIR=$(realpath "$1")
GPUS=$2
else
echo "Usage: $0 <local model dir> <gpu number>"
exit 1
fi
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1
PY_FILE=./generate.py
function t2v_1_3B() {
T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function t2v_14B() {
T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function t2i_14B() {
T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function i2v_14B_480p() {
I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function i2v_14B_720p() {
I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p

View File

@@ -0,0 +1,2 @@
from . import configs, modules
from .text2video_mlx import WanT2V
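
# Public entry points: `wan.WanT2V` is the MLX text-to-video pipeline used by
# generate_mlx.py, and `wan.configs` exposes WAN_CONFIGS / SIZE_CONFIGS /
# MAX_AREA_CONFIGS / SUPPORTED_SIZES.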

View File

@@ -0,0 +1,42 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B
# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'
WAN_CONFIGS = {
't2v-14B': t2v_14B,
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
}
SIZE_CONFIGS = {
'720*1280': (720, 1280),
'1280*720': (1280, 720),
'480*832': (480, 832),
'832*480': (832, 480),
'1024*1024': (1024, 1024),
}
MAX_AREA_CONFIGS = {
'720*1280': 720 * 1280,
'1280*720': 1280 * 720,
'480*832': 480 * 832,
'832*480': 832 * 480,
}
SUPPORTED_SIZES = {
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2v-1.3B': ('480*832', '832*480'),
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
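
# Example of how these tables are consumed (mirrors generate_mlx.py):
#   cfg = WAN_CONFIGS['t2v-1.3B']
#   assert '480*832' in SUPPORTED_SIZES['t2v-1.3B']
#   size = SIZE_CONFIGS['480*832']   # -> (480, 832)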

View File

@@ -0,0 +1,19 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()
# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.float32
wan_shared_cfg.text_len = 512
# transformer
wan_shared_cfg.param_dtype = torch.float32
# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt (Chinese). Roughly: garish colors, overexposed, static, blurry
# details, subtitles, style, artwork, painting, still frame, overall gray, worst/low quality,
# JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn hands/face, deformed,
# disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, many people in the background, walking backwards.
wan_shared_cfg.sample_neg_prompt = '色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走'

View File

@@ -0,0 +1,35 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan I2V 14B ------------------------#
i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float32
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)
# transformer
i2v_14B.patch_size = (1, 2, 2)
i2v_14B.dim = 5120
i2v_14B.ffn_dim = 13824
i2v_14B.freq_dim = 256
i2v_14B.num_heads = 40
i2v_14B.num_layers = 40
i2v_14B.window_size = (-1, -1)
i2v_14B.qk_norm = True
i2v_14B.cross_attn_norm = True
i2v_14B.eps = 1e-6

View File

@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan T2V 14B ------------------------#
t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)
# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
# vae
t2v_14B.vae_checkpoint = 'vae_mlx.safetensors'
t2v_14B.vae_stride = (4, 8, 8)
# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6

View File

@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan T2V 1.3B ------------------------#
t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)
# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
# vae
t2v_1_3B.vae_checkpoint = 'vae_mlx.safetensors'
t2v_1_3B.vae_stride = (4, 8, 8)
# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6

View File

@@ -0,0 +1,14 @@
from .model_mlx import WanModel
from .t5_mlx import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae_mlx import WanVAE
__all__ = [
'WanVAE',
'WanModel',
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
'HuggingfaceTokenizer',
]

View File

@@ -0,0 +1,787 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# MLX Implementation of WAN Model - True 1:1 Port from PyTorch
import math
from typing import List, Tuple, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
"""Generate sinusoidal position embeddings."""
assert dim % 2 == 0
half = dim // 2
position = position.astype(mx.float32)
# Calculate sinusoidal embeddings
div_term = mx.power(10000, mx.arange(half).astype(mx.float32) / half)
sinusoid = mx.expand_dims(position, 1) / mx.expand_dims(div_term, 0)
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
def rope_params(max_seq_len: int, dim: int, theta: float = 10000) -> mx.array:
"""Generate RoPE (Rotary Position Embedding) parameters."""
assert dim % 2 == 0
positions = mx.arange(max_seq_len)
freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
freqs = 1.0 / mx.power(theta, freqs)
# Outer product
freqs = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)
# Convert to complex representation
return mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
def rope_apply(x: mx.array, grid_sizes: mx.array, freqs: mx.array) -> mx.array:
"""Apply rotary position embeddings to input tensor."""
n, c_half = x.shape[2], x.shape[3] // 2
# Split frequencies for different dimensions
c_split = c_half - 2 * (c_half // 3)
freqs_splits = [
freqs[:, :c_split],
freqs[:, c_split:c_split + c_half // 3],
freqs[:, c_split + c_half // 3:]
]
output = []
for i in range(grid_sizes.shape[0]):
f, h, w = int(grid_sizes[i, 0]), int(grid_sizes[i, 1]), int(grid_sizes[i, 2])
seq_len = f * h * w
# Extract sequence for current sample
x_i = x[i, :seq_len].astype(mx.float32)
x_i = x_i.reshape(seq_len, n, -1, 2)
# Prepare frequency tensors
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
freqs_f = mx.broadcast_to(freqs_f, (f, h, w, freqs_f.shape[-2], 2))
freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
freqs_h = mx.broadcast_to(freqs_h, (f, h, w, freqs_h.shape[-2], 2))
freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
freqs_w = mx.broadcast_to(freqs_w, (f, h, w, freqs_w.shape[-2], 2))
# Concatenate and reshape frequencies
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=-2)
freqs_i = freqs_i.reshape(seq_len, 1, -1, 2)
# Apply rotary embedding
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
freqs_cos = freqs_i[..., 0]
freqs_sin = freqs_i[..., 1]
x_rotated_real = x_real * freqs_cos - x_imag * freqs_sin
x_rotated_imag = x_real * freqs_sin + x_imag * freqs_cos
x_i = mx.stack([x_rotated_real, x_rotated_imag], axis=-1).reshape(seq_len, n, -1)
# Concatenate with remaining sequence if any
if seq_len < x.shape[1]:
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
output.append(x_i)
return mx.stack(output).astype(x.dtype)
class WanRMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
# RMS normalization
variance = mx.mean(mx.square(x.astype(mx.float32)), axis=-1, keepdims=True)
x_normed = x * mx.rsqrt(variance + self.eps)
return (x_normed * self.weight).astype(x.dtype)
class WanLayerNorm(nn.Module):
"""Layer normalization."""
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
super().__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if elementwise_affine:
self.weight = mx.ones((dim,))
self.bias = mx.zeros((dim,))
def __call__(self, x: mx.array) -> mx.array:
# Standard layer normalization
x_float = x.astype(mx.float32)
mean = mx.mean(x_float, axis=-1, keepdims=True)
variance = mx.var(x_float, axis=-1, keepdims=True)
x_normed = (x_float - mean) * mx.rsqrt(variance + self.eps)
if self.elementwise_affine:
x_normed = x_normed * self.weight + self.bias
return x_normed.astype(x.dtype)
def mlx_attention(
q: mx.array,
k: mx.array,
v: mx.array,
q_lens: Optional[mx.array] = None,
k_lens: Optional[mx.array] = None,
dropout_p: float = 0.,
softmax_scale: Optional[float] = None,
q_scale: Optional[float] = None,
causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
dtype: Optional[type] = None,
) -> mx.array:
"""
MLX implementation of scaled dot-product attention.
"""
# Get shapes
b, lq, n, d = q.shape
_, lk, _, _ = k.shape
# Scale queries if needed
if q_scale is not None:
q = q * q_scale
# Compute attention scores
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]
# Compute attention scores
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]
# Apply softmax scale if provided
if softmax_scale is not None:
scores = scores * softmax_scale
else:
# Default scaling by sqrt(d)
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
# Create attention mask
attn_mask = None
# Apply window size masking if specified
if window_size != (-1, -1):
left_window, right_window = window_size
window_mask = mx.zeros((lq, lk))
for i in range(lq):
start = max(0, i - left_window)
end = min(lk, i + right_window + 1)
window_mask[i, start:end] = 1
attn_mask = window_mask
# Apply causal masking if needed
if causal:
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
if attn_mask is None:
attn_mask = causal_mask
else:
attn_mask = mx.logical_and(attn_mask, causal_mask)
# Apply attention mask if present
if attn_mask is not None:
attn_mask = attn_mask.astype(scores.dtype)
        scores = scores * attn_mask + (1 - attn_mask) * -1e4  # masked positions get a large negative logit
# Apply attention mask if lengths are provided
if q_lens is not None or k_lens is not None:
if q_lens is not None:
mask = mx.arange(lq)[None, :] < q_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
if k_lens is not None:
mask = mx.arange(lk)[None, :] < k_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
    # Numerically stable softmax: subtract the row-wise max before exponentiating
max_scores = mx.max(scores, axis=-1, keepdims=True)
scores = scores - max_scores
exp_scores = mx.exp(scores)
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
attn = exp_scores / (sum_exp + 1e-6)
# Apply dropout if needed
if dropout_p > 0 and not deterministic:
raise NotImplementedError("Dropout not implemented in MLX version")
# Compute output
out = mx.matmul(attn, v) # [b, n, lq, d]
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]
return out
class WanSelfAttention(nn.Module):
"""Self-attention module with RoPE and optional QK normalization."""
def __init__(self, dim: int, num_heads: int, window_size: Tuple[int, int] = (-1, -1),
qk_norm: bool = True, eps: float = 1e-6):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# Linear projections
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
# Normalization layers
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def __call__(self, x: mx.array, seq_lens: mx.array, grid_sizes: mx.array,
freqs: mx.array) -> mx.array:
b, s = x.shape[0], x.shape[1]
# Compute Q, K, V
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.qk_norm:
q = self.norm_q(q)
k = self.norm_k(k)
# Reshape for multi-head attention
q = q.reshape(b, s, self.num_heads, self.head_dim)
k = k.reshape(b, s, self.num_heads, self.head_dim)
v = v.reshape(b, s, self.num_heads, self.head_dim)
# Apply RoPE
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# Apply attention
x = mlx_attention(q, k, v, k_lens=seq_lens, window_size=self.window_size)
# Reshape and project output
x = x.reshape(b, s, self.dim)
x = self.o(x)
return x
class WanT2VCrossAttention(WanSelfAttention):
"""Text-to-video cross attention."""
def __call__(self, x: mx.array, context: mx.array, context_lens: mx.array) -> mx.array:
b = x.shape[0]
# Compute queries from x
q = self.q(x)
if self.qk_norm:
q = self.norm_q(q)
q = q.reshape(b, -1, self.num_heads, self.head_dim)
# Compute keys and values from context
k = self.k(context)
v = self.v(context)
if self.qk_norm:
k = self.norm_k(k)
k = k.reshape(b, -1, self.num_heads, self.head_dim)
v = v.reshape(b, -1, self.num_heads, self.head_dim)
# Apply attention
x = mlx_attention(q, k, v, k_lens=context_lens)
# Reshape and project output
x = x.reshape(b, -1, self.dim)
x = self.o(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
"""Image-to-video cross attention."""
def __init__(self, dim: int, num_heads: int, window_size: Tuple[int, int] = (-1, -1),
qk_norm: bool = True, eps: float = 1e-6):
super().__init__(dim, num_heads, window_size, qk_norm, eps)
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def __call__(self, x: mx.array, context: mx.array, context_lens: mx.array) -> mx.array:
# Split context into image and text parts
context_img = context[:, :257]
context = context[:, 257:]
b = x.shape[0]
# Compute queries
q = self.q(x)
if self.qk_norm:
q = self.norm_q(q)
q = q.reshape(b, -1, self.num_heads, self.head_dim)
# Compute keys and values for text
k = self.k(context)
v = self.v(context)
if self.qk_norm:
k = self.norm_k(k)
k = k.reshape(b, -1, self.num_heads, self.head_dim)
v = v.reshape(b, -1, self.num_heads, self.head_dim)
# Compute keys and values for image
k_img = self.k_img(context_img)
v_img = self.v_img(context_img)
if self.qk_norm:
k_img = self.norm_k_img(k_img)
k_img = k_img.reshape(b, -1, self.num_heads, self.head_dim)
v_img = v_img.reshape(b, -1, self.num_heads, self.head_dim)
# Apply attention
img_x = mlx_attention(q, k_img, v_img, k_lens=None)
x = mlx_attention(q, k, v, k_lens=context_lens)
# Combine and project
img_x = img_x.reshape(b, -1, self.dim)
x = x.reshape(b, -1, self.dim)
x = x + img_x
x = self.o(x)
return x
WAN_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention,
}
class WanAttentionBlock(nn.Module):
"""Transformer block with self-attention, cross-attention, and FFN."""
def __init__(self, cross_attn_type: str, dim: int, ffn_dim: int, num_heads: int,
window_size: Tuple[int, int] = (-1, -1), qk_norm: bool = True,
cross_attn_norm: bool = False, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# Layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
dim, num_heads, (-1, -1), qk_norm, eps)
self.norm2 = WanLayerNorm(dim, eps)
# FFN - use a list instead of Sequential to match PyTorch exactly!
self.ffn = [
nn.Linear(dim, ffn_dim),
nn.GELU(),
nn.Linear(ffn_dim, dim)
]
# Modulation parameters
self.modulation = mx.random.normal((1, 6, dim)) / math.sqrt(dim)
def __call__(self, x: mx.array, e: mx.array, seq_lens: mx.array,
grid_sizes: mx.array, freqs: mx.array, context: mx.array,
context_lens: Optional[mx.array]) -> mx.array:
# Apply modulation
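        # The six modulation chunks are adaLN-style vectors:
        #   [shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn]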
e = (self.modulation + e).astype(mx.float32)
e_chunks = [mx.squeeze(chunk, axis=1) for chunk in mx.split(e, 6, axis=1)]
# Self-attention with modulation
y = self.norm1(x).astype(mx.float32)
y = y * (1 + e_chunks[1]) + e_chunks[0]
y = self.self_attn(y, seq_lens, grid_sizes, freqs)
x = x + y * e_chunks[2]
# Cross-attention
if self.cross_attn_norm and isinstance(self.norm3, WanLayerNorm):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
else:
x = x + self.cross_attn(x, context, context_lens)
# FFN with modulation
y = self.norm2(x).astype(mx.float32)
y = y * (1 + e_chunks[4]) + e_chunks[3]
# Apply FFN layers manually
y = self.ffn[0](y) # Linear
y = self.ffn[1](y) # GELU
y = self.ffn[2](y) # Linear
x = x + y * e_chunks[5]
return x
class Head(nn.Module):
"""Output head for final projection."""
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int],
eps: float = 1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# Output projection
out_features = int(np.prod(patch_size)) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_features)
# Modulation
self.modulation = mx.random.normal((1, 2, dim)) / math.sqrt(dim)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
# Apply modulation
e = (self.modulation + mx.expand_dims(e, 1)).astype(mx.float32)
e_chunks = mx.split(e, 2, axis=1)
# Apply normalization and projection with modulation
x = self.norm(x) * (1 + e_chunks[1]) + e_chunks[0]
x = self.head(x)
return x
class MLPProj(nn.Module):
"""MLP projection for image embeddings."""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
# Use a list to match PyTorch Sequential indexing
self.proj = [
nn.LayerNorm(in_dim),
nn.Linear(in_dim, in_dim),
nn.GELU(),
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
]
def __call__(self, image_embeds: mx.array) -> mx.array:
x = image_embeds
for layer in self.proj:
x = layer(x)
return x
class WanModel(nn.Module):
"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
MLX implementation - True 1:1 port from PyTorch.
"""
def __init__(
self,
model_type: str = 't2v',
patch_size: Tuple[int, int, int] = (1, 2, 2),
text_len: int = 512,
in_dim: int = 16,
dim: int = 2048,
ffn_dim: int = 8192,
freq_dim: int = 256,
text_dim: int = 4096,
out_dim: int = 16,
num_heads: int = 16,
num_layers: int = 32,
window_size: Tuple[int, int] = (-1, -1),
qk_norm: bool = True,
cross_attn_norm: bool = True,
eps: float = 1e-6
):
super().__init__()
assert model_type in ['t2v', 'i2v']
self.model_type = model_type
# Store configuration
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# Embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim,
kernel_size=patch_size,
stride=patch_size
)
# Use lists instead of Sequential to match PyTorch!
self.text_embedding = [
nn.Linear(text_dim, dim),
nn.GELU(),
nn.Linear(dim, dim)
]
self.time_embedding = [
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
]
self.time_projection = [
nn.SiLU(),
nn.Linear(dim, dim * 6)
]
# Transformer blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = [
WanAttentionBlock(
cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps
)
for _ in range(num_layers)
]
# Output head
self.head = Head(dim, out_dim, patch_size, eps)
# Precompute RoPE frequencies
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], axis=1)
# Image embedding for i2v
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim)
# Initialize weights
self.init_weights()
def __call__(
self,
x: List[mx.array],
t: mx.array,
context: List[mx.array],
seq_len: int,
clip_fea: Optional[mx.array] = None,
y: Optional[List[mx.array]] = None
) -> List[mx.array]:
"""
Forward pass through the diffusion model.
Args:
x: List of input video tensors [C_in, F, H, W]
t: Diffusion timesteps [B]
context: List of text embeddings [L, C]
seq_len: Maximum sequence length
clip_fea: CLIP image features for i2v mode
y: Conditional video inputs for i2v mode
Returns:
List of denoised video tensors [C_out, F, H/8, W/8]
"""
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# Concatenate conditional inputs if provided
if y is not None:
x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]
# Patch embedding
x = [mx.transpose(mx.expand_dims(u, 0), (0, 2, 3, 4, 1)) for u in x]
x = [self.patch_embedding(u) for u in x]
# Transpose back from MLX format (N, D, H, W, C) to (N, C, D, H, W) for the rest of the model
x = [mx.transpose(u, (0, 4, 1, 2, 3)) for u in x]
grid_sizes = mx.array([[u.shape[2], u.shape[3], u.shape[4]] for u in x])
# Flatten spatial dimensions
x = [mx.transpose(u.reshape(u.shape[0], u.shape[1], -1), (0, 2, 1)) for u in x]
seq_lens = mx.array([u.shape[1] for u in x])
# Pad sequences to max length
x_padded = []
for u in x:
if u.shape[1] < seq_len:
padding = mx.zeros((1, seq_len - u.shape[1], u.shape[2]))
u = mx.concatenate([u, padding], axis=1)
x_padded.append(u)
x = mx.concatenate(x_padded, axis=0)
# Time embeddings - apply layers manually
e = sinusoidal_embedding_1d(self.freq_dim, t).astype(mx.float32)
e = self.time_embedding[0](e) # Linear
e = self.time_embedding[1](e) # SiLU
e = self.time_embedding[2](e) # Linear
# Time projection
e = self.time_projection[0](e) # SiLU
e0 = self.time_projection[1](e).reshape(-1, 6, self.dim) # Linear
# Process context
context_lens = None
context_padded = []
for u in context:
if u.shape[0] < self.text_len:
padding = mx.zeros((self.text_len - u.shape[0], u.shape[1]))
u = mx.concatenate([u, padding], axis=0)
context_padded.append(u)
context = mx.stack(context_padded)
# Apply text embedding layers manually
context = self.text_embedding[0](context) # Linear
context = self.text_embedding[1](context) # GELU
context = self.text_embedding[2](context) # Linear
# Add image embeddings for i2v
if clip_fea is not None:
context_clip = self.img_emb(clip_fea)
context = mx.concatenate([context_clip, context], axis=1)
# Apply transformer blocks
for block in self.blocks:
x = block(
x, e0, seq_lens, grid_sizes, self.freqs,
context, context_lens
)
# Apply output head
x = self.head(x, e)
# Unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.astype(mx.float32) for u in x]
def unpatchify(self, x: mx.array, grid_sizes: mx.array) -> List[mx.array]:
"""Reconstruct video tensors from patch embeddings."""
c = self.out_dim
out = []
for i in range(grid_sizes.shape[0]):
f, h, w = int(grid_sizes[i, 0]), int(grid_sizes[i, 1]), int(grid_sizes[i, 2])
seq_len = f * h * w
# Extract relevant sequence
u = x[i, :seq_len]
# Reshape to grid with patches
pf, ph, pw = self.patch_size
u = u.reshape(f, h, w, pf, ph, pw, c)
# Rearrange dimensions
u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
# Combine patches
u = u.reshape(c, f * pf, h * ph, w * pw)
out.append(u)
return out
def init_weights(self):
"""Initialize model parameters using Xavier/He initialization."""
# Note: MLX doesn't have nn.init like PyTorch, so we manually initialize
# Helper function for Xavier uniform initialization
def xavier_uniform(shape):
bound = mx.sqrt(6.0 / (shape[0] + shape[1]))
return mx.random.uniform(low=-bound, high=bound, shape=shape)
# Initialize linear layers in blocks
for block in self.blocks:
# Self attention
block.self_attn.q.weight = xavier_uniform(block.self_attn.q.weight.shape)
block.self_attn.k.weight = xavier_uniform(block.self_attn.k.weight.shape)
block.self_attn.v.weight = xavier_uniform(block.self_attn.v.weight.shape)
block.self_attn.o.weight = xavier_uniform(block.self_attn.o.weight.shape)
# Cross attention
block.cross_attn.q.weight = xavier_uniform(block.cross_attn.q.weight.shape)
block.cross_attn.k.weight = xavier_uniform(block.cross_attn.k.weight.shape)
block.cross_attn.v.weight = xavier_uniform(block.cross_attn.v.weight.shape)
block.cross_attn.o.weight = xavier_uniform(block.cross_attn.o.weight.shape)
# FFN layers - now it's a list!
block.ffn[0].weight = xavier_uniform(block.ffn[0].weight.shape)
block.ffn[2].weight = xavier_uniform(block.ffn[2].weight.shape)
# Modulation
block.modulation = mx.random.normal(
shape=(1, 6, self.dim),
scale=1.0 / math.sqrt(self.dim)
)
# Special initialization for embeddings
# Patch embedding - Xavier uniform
weight_shape = self.patch_embedding.weight.shape
fan_in = weight_shape[1] * np.prod(self.patch_size)
fan_out = weight_shape[0]
bound = mx.sqrt(6.0 / (fan_in + fan_out))
self.patch_embedding.weight = mx.random.uniform(
low=-bound,
high=bound,
shape=weight_shape
)
# Text embedding - normal distribution with std=0.02
self.text_embedding[0].weight = mx.random.normal(shape=self.text_embedding[0].weight.shape, scale=0.02)
self.text_embedding[2].weight = mx.random.normal(shape=self.text_embedding[2].weight.shape, scale=0.02)
# Time embedding - normal distribution with std=0.02
self.time_embedding[0].weight = mx.random.normal(shape=self.time_embedding[0].weight.shape, scale=0.02)
self.time_embedding[2].weight = mx.random.normal(shape=self.time_embedding[2].weight.shape, scale=0.02)
# Output head - initialize to zeros
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
# Head modulation
self.head.modulation = mx.random.normal(
shape=(1, 2, self.dim),
scale=1.0 / math.sqrt(self.dim)
)
# Initialize i2v specific layers if present
if self.model_type == 'i2v':
for i in [1, 3]: # Linear layers in the proj list
if isinstance(self.img_emb.proj[i], nn.Linear):
self.img_emb.proj[i].weight = xavier_uniform(self.img_emb.proj[i].weight.shape)
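
# Minimal shape sketch for the 1.3B T2V configuration (illustrative only; the actual
# pipeline loads pretrained weights from the checkpoint directory before calling the model):
#
#   model = WanModel(model_type='t2v', dim=1536, ffn_dim=8960, num_heads=12,
#                    num_layers=30, patch_size=(1, 2, 2))
#   latent = mx.zeros((16, 1, 60, 104))       # [C_in, F, H/8, W/8] for one 480*832 frame
#   t = mx.array([999])                       # diffusion timestep per batch element
#   context = [mx.zeros((512, 4096))]         # text embeddings [L, C] from the T5 encoder
#   out = model([latent], t, context, seq_len=(60 // 2) * (104 // 2))  # 1560 patch tokens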

View File

@@ -0,0 +1,513 @@
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp, max=clamp)
return x
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
elif isinstance(m, T5Model):
nn.init.normal_(m.token_embedding.weight, std=1.0)
elif isinstance(m, T5FeedForward):
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
elif isinstance(m, T5Attention):
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.float32]:
x = x.type_as(self.weight)
return self.weight * x
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1,
-1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum('bnij,bjnc->binc', attn, v)
# output
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5CrossAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def forward(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def forward(self, lq, lk):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long()
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
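# Illustrative note (added for clarity, not part of the original file): with the
# umt5-xxl config (num_buckets=32, bidirectional=True) each direction gets 16
# buckets and max_exact = 8. Offsets |d| < 8 keep their own bucket, while larger
# offsets are binned logarithmically up to max_dist=128, e.g. d = -3 -> bucket 3,
# d = +3 -> bucket 19, d = +20 -> bucket 26. A minimal usage sketch:
#
#   emb = T5RelativeEmbedding(num_buckets=32, num_heads=64, bidirectional=True)
#   pos_bias = emb(16, 16)   # [1, 64, 16, 16], added to the attention logits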
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Encoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Decoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.size()
# causal mask
if mask is None:
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
elif mask.ndim == 2:
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Model, self).__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
# initialize weights
self.apply(init_weights)
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
dtype=torch.float32,
device='cpu',
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.1)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
dtype=torch.float32,
device='mps' if torch.backends.mps.is_available() else 'cpu',
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=text_len, clean='whitespace')
def __call__(self, texts, device):
ids, mask = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
ids = ids.to(device)
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)
return [u[:v] for u, v in zip(context, seq_lens)]
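# Usage sketch (illustrative, not part of the original file; the checkpoint path
# below is a placeholder for the UMT5-XXL encoder weights shipped with Wan2.1):
#
#   text_encoder = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='path/to/umt5-xxl-enc.pth',   # placeholder path
#       tokenizer_path='google/umt5-xxl')
#   context = text_encoder(['a cat playing the piano'], device='cpu')
#   # -> list with one [seq_len, 4096] tensor per prompt, trimmed to its real length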

View File

@ -0,0 +1,617 @@
# Modified from transformers.models.t5.modeling_t5 for MLX
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == mx.float16:
# Use same clamping as PyTorch for consistency
clamp = 65504.0 # max value for float16
return mx.clip(x, -clamp, clamp)
return x
class GELU(nn.Module):
def __call__(self, x):
return 0.5 * x * (1.0 + mx.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
# Match PyTorch's approach: convert to float32 for stability
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
x_norm = x_float * mx.rsqrt(variance + self.eps)
# Convert back to original dtype
if x.dtype == mx.float16:
x_norm = x_norm.astype(mx.float16)
return self.weight * x_norm
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
assert dim_attn % num_heads == 0
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, l1, _ = x.shape
_, l2, _ = context.shape
n, c = self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, l1, n, c)
k = self.k(context).reshape(b, l2, n, c)
v = self.v(context).reshape(b, l2, n, c)
# transpose for attention: [B, N, L, C]
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# compute attention (T5 does not use scaling)
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
# add position bias if provided
if pos_bias is not None:
attn = attn + pos_bias
# apply mask
if mask is not None:
if mask.ndim == 2:
# [B, L2] -> [B, 1, 1, L2]
mask = mask[:, None, None, :]
elif mask.ndim == 3:
# [B, L1, L2] -> [B, 1, L1, L2]
mask = mask[:, None, :, :]
# Use very negative value that works well with float16
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
attn = mx.where(mask == 0, min_value, attn)
# softmax and apply attention
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
attn = self.dropout(attn)
# apply attention to values
x = mx.matmul(attn, v) # [B, N, L1, C]
# transpose back and reshape
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
x = x.reshape(b, l1, -1)
# output projection
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.0):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = GELU()
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x):
gate = self.gate_act(self.gate_proj(x))
x = self.fc1(x) * gate
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def __call__(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def __call__(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def __call__(self, lq, lk):
# Create relative position matrix
positions_q = mx.arange(lq)[:, None]
positions_k = mx.arange(lk)[None, :]
rel_pos = positions_k - positions_q
# Apply bucketing
rel_pos = self._relative_position_bucket(rel_pos)
# Get embeddings
rel_pos_embeds = self.embedding(rel_pos)
# Reshape to [1, N, Lq, Lk]
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
return rel_pos_embeds
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
            # mx.zeros_like takes no dtype argument, so build the integer buckets explicitly
            rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
# For large positions, use log scale
rel_pos_large = max_exact + (
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
# Combine small and large position buckets
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.shape
# causal mask
if mask is None:
mask = mx.tril(mx.ones((1, s, s)))
elif mask.ndim == 2:
# Expand mask properly
            mask = mx.tril(mx.broadcast_to(mx.expand_dims(mask, 1), (b, s, s)))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def init_mlx_weights(module, key):
"""Initialize weights for T5 model components to match PyTorch initialization"""
def normal(key, shape, std=1.0):
        # mx.random.normal takes the shape first; the PRNG key is a keyword argument
        return mx.random.normal(shape=shape, key=key) * std
if isinstance(module, T5LayerNorm):
module.weight = mx.ones_like(module.weight)
elif isinstance(module, nn.Embedding):
key = mx.random.split(key, 1)[0]
module.weight = normal(key, module.weight.shape, std=1.0)
elif isinstance(module, T5FeedForward):
# Match PyTorch initialization
key1, key2, key3 = mx.random.split(key, 3)
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
std=module.dim**-0.5)
module.fc1.weight = normal(key2, module.fc1.weight.shape,
std=module.dim**-0.5)
module.fc2.weight = normal(key3, module.fc2.weight.shape,
std=module.dim_ffn**-0.5)
elif isinstance(module, T5Attention):
# Match PyTorch initialization
        key1, key2, key3, key4 = mx.random.split(key, 4)
module.q.weight = normal(key1, module.q.weight.shape,
std=(module.dim * module.dim_attn)**-0.5)
module.k.weight = normal(key2, module.k.weight.shape,
std=module.dim**-0.5)
module.v.weight = normal(key3, module.v.weight.shape,
std=module.dim**-0.5)
module.o.weight = normal(key4, module.o.weight.shape,
std=(module.num_heads * module.dim_attn)**-0.5)
elif isinstance(module, T5RelativeEmbedding):
key = mx.random.split(key, 1)[0]
module.embedding.weight = normal(key, module.embedding.weight.shape,
std=(2 * module.num_buckets * module.num_heads)**-0.5)
elif isinstance(module, nn.Linear):
# Generic linear layer initialization
key = mx.random.split(key, 1)[0]
fan_in = module.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)
        module.weight = mx.random.uniform(
            low=-bound, high=bound, shape=module.weight.shape, key=key)
return module
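# Note (added for clarity, not in the original file): unlike PyTorch's
# nn.Module.apply, calling init_mlx_weights on the top-level model only touches
# that single object; mirroring the PyTorch init would require applying it to
# every submodule (e.g. by iterating over model.named_modules()). In practice the
# random init is overwritten by the converted checkpoint loaded in T5EncoderModel.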
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
model = model_cls(**kwargs)
# Initialize weights properly
key = mx.random.key(0)
model = init_mlx_weights(model, key)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.0)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
checkpoint_path=None,
tokenizer_path=None,
):
self.text_len = text_len
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False)
if checkpoint_path:
logging.info(f'loading {checkpoint_path}')
# Load weights - assuming MLX format checkpoint
weights = mx.load(checkpoint_path)
model.update(tree_unflatten(list(weights.items())))
self.model = model
# init tokenizer
from .tokenizers import HuggingfaceTokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
seq_len=text_len,
clean='whitespace')
def __call__(self, texts):
# Handle single string input
if isinstance(texts, str):
texts = [texts]
# Tokenize texts
tokenizer_output = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
# Handle different tokenizer output formats
if isinstance(tokenizer_output, tuple):
ids, mask = tokenizer_output
else:
# Assuming dict output with 'input_ids' and 'attention_mask'
ids = tokenizer_output['input_ids']
mask = tokenizer_output['attention_mask']
# Convert to MLX arrays if not already
if not isinstance(ids, mx.array):
ids = mx.array(ids)
if not isinstance(mask, mx.array):
mask = mx.array(mask)
# Get sequence lengths
seq_lens = mx.sum(mask > 0, axis=1)
# Run encoder
context = self.model(ids, mask)
# Return variable length outputs
# Convert seq_lens to Python list for indexing
if seq_lens.ndim == 0: # Single value
seq_lens_list = [seq_lens.item()]
else:
seq_lens_list = seq_lens.tolist()
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
# Utility function to convert PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
"""Convert PyTorch checkpoint to MLX format"""
import torch
# Load PyTorch checkpoint
pytorch_state = torch.load(pytorch_path, map_location='cpu')
# Convert to numpy then to MLX
mlx_state = {}
for key, value in pytorch_state.items():
if isinstance(value, torch.Tensor):
# Handle the key mapping if needed
mlx_key = key
            # Convert tensor to MLX array (cast bfloat16 to float32 first, since
            # bfloat16 tensors cannot be converted through .numpy())
            tensor = value.detach().cpu()
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.to(torch.float32)
            mlx_state[mlx_key] = mx.array(tensor.numpy())
    # Save MLX checkpoint (the dict of arrays is written in safetensors format)
    mx.save_safetensors(mlx_path, mlx_state)
return mlx_state
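# Usage sketch (illustrative, not part of the original file; file names are
# placeholders and the PyTorch parameter names are assumed to already match this
# module tree):
#
#   convert_pytorch_checkpoint('umt5-xxl-enc.pth', 'umt5-xxl-enc.safetensors')
#   text_encoder = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='umt5-xxl-enc.safetensors',
#       tokenizer_path='google/umt5-xxl')
#   context = text_encoder('a cat playing the piano')
#   # -> list with one [seq_len, 4096] mx.array per prompt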

View File

@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ['HuggingfaceTokenizer']
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace('_', ' ')
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans('', '', string.punctuation))
for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans('', '', string.punctuation))
text = text.lower()
text = re.sub(r'\s+', ' ', text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop('return_mask', False)
# arguments
_kwargs = {'return_tensors': 'pt'}
if self.seq_len is not None:
_kwargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.seq_len
})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
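# Usage sketch (illustrative, not part of the original file):
#
#   tokenizer = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
#   ids, mask = tokenizer(['a cat playing the piano'],
#                         return_mask=True, add_special_tokens=True)
#   # ids, mask: [1, 512] tensors, padded/truncated to seq_len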

View File

@ -0,0 +1,702 @@
# Original PyTorch implementation of Wan VAE
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
__all__ = [
'WanVAE',
]
CACHE_T = 2
debug_line = 0
def debug(name, x):
global debug_line
print(f"LINE {debug_line}: {name}: shape = {tuple(x.shape)}, mean = {x.mean().item():.4f}, std = {x.std().item():.4f}")
debug_line += 1
return x
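# Note (added for clarity, not in the original file): the encoder/decoder below
# process long videos in temporal chunks. Around every CausalConv3d the calling
# block stores the last CACHE_T (=2) input frames in feat_cache and feeds them
# back in for the next chunk, so the causal front padding is filled with real
# frames from the previous chunk and chunked processing matches a single
# full-length pass. feat_idx is a one-element list used as a mutable cursor
# into feat_cache.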
class CausalConv3d(nn.Conv3d):
"""
    Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
result = super().forward(x)
debug("TORCH x after conv3d", result)
return result
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
result = super().forward(x.float()).type_as(x)
debug("TORCH x after upsample", result)
return result
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
                        # cache the last frame of the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
debug("TORCH x after time_conv", x)
else:
x = self.time_conv(x, feat_cache[idx])
debug("TORCH x after time_conv with cache", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
debug("TORCH x after resample", x)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
debug("TORCH x after shortcut", h)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
debug("TORCH x after residual block", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
debug("TORCH x after residual block", x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
debug("TORCH x after norm", x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
debug("TORCH x after proj", x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
debug("TORCH x after conv1 with cache", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
debug("TORCH x after conv1", x)
## downsamples
for i, layer in enumerate(self.downsamples):
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after downsample layer", x)
else:
x = layer(x)
debug("TORCH x after downsample layer", x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after downsample layer", x)
else:
x = layer(x)
debug("TORCH x after downsample layer", x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
debug("TORCH x after downsample layer", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
debug("TORCH x after downsample layer", x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
debug("TORCH x after conv1 with cache", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
debug("TORCH x after conv1", x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after middle layer", x)
else:
x = layer(x)
debug("TORCH x after middle layer", x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after upsample layer", x)
else:
x = layer(x)
debug("TORCH x after upsample layer", x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
debug("TORCH x after head layer", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
debug("TORCH x after head layer", x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
        ## split the encoder input x along time into chunks of 1, 4, 4, 4, ...
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
# params
cfg = dict(
dim=96,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0)
cfg.update(**kwargs)
# init model
with torch.device('meta'):
model = WanVAE_(**cfg)
# load checkpoint
logging.info(f'loading {pretrained_path}')
model.load_state_dict(
torch.load(pretrained_path, map_location=device), assign=True)
return model
class WanVAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=torch.float,
device="cpu"):
self.dtype = dtype
self.device = device
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=dtype, device=device)
self.std = torch.tensor(std, dtype=dtype, device=device)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
def decode(self, zs):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
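# Usage sketch (illustrative, not part of the original file): the VAE expects
# videos normalized to [-1, 1] with shape [C, T, H, W] and T = 1 + 4k frames.
#
#   vae = WanVAE(vae_pth='cache/vae_step_411000.pth', device='cpu')
#   video = torch.rand(3, 17, 480, 832) * 2 - 1
#   latents = vae.encode([video])    # list of [16, 5, 60, 104] latents (8x spatial, 4x temporal)
#   recon = vae.decode(latents)      # list of [3, 17, 480, 832] reconstructions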

View File

@ -0,0 +1,719 @@
# vae_mlx_final.py
import logging
from typing import Optional, List, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
__all__ = [
'WanVAE',
]
CACHE_T = 2
debug_line = 0
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolution for MLX.
Expects input in BTHWC format (batch, time, height, width, channels).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Padding order: (W, W, H, H, T, 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def __call__(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
padding[4] -= cache_x.shape[1]
# Pad in BTHWC format
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
(padding[0], padding[1]), (0, 0)]
x = mx.pad(x, pad_width)
result = super().__call__(x)
return result
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=False, images=True, bias=False):
super().__init__()
self.channel_first = channel_first
self.images = images
self.scale = dim**0.5
# Just keep as 1D - let broadcasting do its magic
self.gamma = mx.ones((dim,))
self.bias = mx.zeros((dim,)) if bias else 0.
def __call__(self, x):
# F.normalize in PyTorch does L2 normalization, not RMS!
# For NHWC/BTHWC format, normalize along the last axis
# L2 norm: sqrt(sum(x^2))
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
x = x / norm
return x * self.scale * self.gamma + self.bias
class Upsample(nn.Module):
"""
Upsampling layer that matches PyTorch's behavior.
"""
def __init__(self, scale_factor, mode='nearest-exact'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode # mode is now unused, but kept for signature consistency
def __call__(self, x):
# For NHWC format (n, h, w, c)
# NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact'
# is equivalent to a simple repeat operation. The previous coordinate-based
# sampling was not correct for this model and caused the divergence.
scale_h, scale_w = self.scale_factor
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
return out
class AsymmetricPad(nn.Module):
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
def __init__(self, pad_width: tuple):
super().__init__()
self.pad_width = pad_width
def __call__(self, x):
return mx.pad(x, self.pad_width)
# Update your Resample class to use 'nearest-exact'
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
# --- CORRECTED PADDING LOGIC ---
elif mode == 'downsample2d':
# Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2)
# Use the new AsymmetricPad module.
# Pad width for NHWC format is ((N), (H), (W), (C))
# Pad H with (top, bottom) and W with (left, right)
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
elif mode == 'downsample3d':
# The spatial downsampling part uses the same logic
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# The __call__ method logic remains unchanged from your original code
b, t, h, w, c = x.shape
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
cache_x = mx.concatenate([
mx.zeros_like(cache_x), cache_x
], axis=1)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, t, h, w, 2, c)
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
x = x.reshape(b, t * 2, h, w, c)
t = x.shape[1]
x = x.reshape(b * t, h, w, c)
x = self.resample(x)
_, h_new, w_new, c_new = x.shape
x = x.reshape(b, t, h_new, w_new, c_new)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_idx[0] += 1
else:
cache_x = x[:, -1:, :, :, :]
x = self.time_conv(
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
CausalConv3d(out_dim, out_dim, 3, padding=1)
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for i, layer in enumerate(self.residual.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
self.proj.weight = mx.zeros_like(self.proj.weight)
def __call__(self, x):
# x is in BTHWC format
identity = x
b, t, h, w, c = x.shape
x = x.reshape(b * t, h, w, c) # Combine batch and time
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
qkv = qkv.reshape(b * t, h * w, 3 * c)
q, k, v = mx.split(qkv, 3, axis=-1)
# Reshape for attention
q = q.reshape(b * t, h * w, c)
k = k.reshape(b * t, h * w, c)
v = v.reshape(b * t, h * w, c)
# Scaled dot product attention
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
scores = (q @ k.transpose(0, 2, 1)) * scale
weights = mx.softmax(scores, axis=-1)
x = weights @ v
x = x.reshape(b * t, h, w, c)
# output
x = self.proj(x)
x = x.reshape(b, t, h, w, c)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[-1], dims[-1], dropout),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1], dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], z_dim, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for i, layer in enumerate(self.downsamples.layers):
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], 3, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples.layers:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for name, module in model.named_modules():
if isinstance(module, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x, scale):
# x is in BTHWC format
self.clear_cache()
## cache
t = x.shape[1]
iter_ = 1 + (t - 1) // 4
## Split encode input x by time into 1, 4, 4, 4....
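        # e.g. t = 17 frames -> iter_ = 5 chunks covering 1, 4, 4, 4, 4 frames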
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :1, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = mx.concatenate([out, out_], axis=1)
z = self.conv1(out)
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
if isinstance(scale[0], mx.array):
# Reshape scale for broadcasting in BTHWC format
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
mu = (mu - scale_mean) * scale_std
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu, log_var
def decode(self, z, scale):
# z is in BTHWC format
self.clear_cache()
if isinstance(scale[0], mx.array):
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
z = z / scale_std + scale_mean
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[1]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = mx.concatenate([out, out_], axis=1)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = mx.exp(0.5 * log_var)
eps = mx.random.normal(std.shape)
return eps * std + mu
def __call__(self, x):
mu, log_var = self.encode(x, self.scale)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z, self.scale)
return x_recon, mu, log_var
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs, self.scale)
if deterministic:
return mu
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
return mu + std * mx.random.normal(std.shape)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
# params
cfg = dict(
dim=96,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0)
cfg.update(**kwargs)
# init model
model = WanVAE_(**cfg)
# load checkpoint
if pretrained_path:
logging.info(f'loading {pretrained_path}')
weights = mx.load(pretrained_path)
model.update(tree_unflatten(list(weights.items())))
return model
class WanVAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=mx.float32):
self.dtype = dtype
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = mx.array(mean, dtype=dtype)
self.std = mx.array(std, dtype=dtype)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
        Returns: A list of encoded latents, each in [C, T, H, W] (channels-first) layout.
"""
encoded = []
for video in videos:
# Convert CTHW -> BTHWC
x = mx.expand_dims(video, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Encode
z = self.model.encode(x, self.scale)[0] # Get mu only
# Convert back BTHWC -> CTHW and remove batch dimension
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
z = z.squeeze(0) # Remove batch dimension -> CTHW
encoded.append(z.astype(mx.float32))
return encoded
def decode(self, zs):
"""
zs: A list of latent codes each with shape [C, T, H, W].
Returns: List of decoded videos in [C, T, H, W] format.
"""
decoded = []
for z in zs:
# Convert CTHW -> BTHWC
x = mx.expand_dims(z, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Decode
x = self.model.decode(x, self.scale)
# Convert back BTHWC -> CTHW and remove batch dimension
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
x = x.squeeze(0) # Remove batch dimension -> CTHW
# Clamp values
x = mx.clip(x, -1, 1)
decoded.append(x.astype(mx.float32))
return decoded
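# Minimal usage sketch (illustrative, not part of the original module). The checkpoint
# path is an assumption: in the pipeline the PyTorch VAE weights are first converted to
# an MLX safetensors file, and that converted file is what gets passed in here.
if __name__ == "__main__":
    vae = WanVAE(z_dim=16, vae_pth="Wan2.1-T2V-1.3B/Wan2.1_VAE_mlx.safetensors")  # hypothetical path
    video = mx.random.normal((3, 5, 64, 64))  # toy [C, T, H, W] clip with values around [-1, 1]
    latents = vae.encode([video])             # -> list with one [16, 2, 8, 8] latent (4x temporal, 8x spatial stride)
    recon = vae.decode(latents)               # -> list with one [3, 5, 64, 64] reconstruction
    print(latents[0].shape, recon[0].shape)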

View File

@ -0,0 +1,170 @@
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['XLMRoberta', 'xlm_roberta_large']
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self,
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
for _ in range(num_layers)
])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + \
self.type_embedding(torch.zeros_like(ids)) + \
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(
mask.view(b, 1, 1, s).gt(0), 0.0,
torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
def xlm_roberta_large(pretrained=False,
return_tokenizer=False,
device='cpu',
**kwargs):
"""
XLMRobertaLarge adapted from Huggingface.
"""
# params
cfg = dict(
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5)
cfg.update(**kwargs)
# init a model on device
with torch.device(device):
model = XLMRoberta(**cfg)
return model
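# Minimal usage sketch (illustrative, not part of the original file): build the
# randomly initialised encoder and run a dummy batch of token ids. Loading real
# pretrained weights is handled elsewhere and is not shown here.
if __name__ == "__main__":
    model = xlm_roberta_large().eval()               # dropout off for inference
    ids = torch.full((1, 16), 5, dtype=torch.long)   # arbitrary non-pad token ids
    with torch.no_grad():
        features = model(ids)                        # -> [1, 16, 1024]
    print(features.shape)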

View File

@ -0,0 +1,310 @@
import json
from typing import Optional, List, Tuple
import mlx.core as mx
from mlx.utils import tree_unflatten
from safetensors import safe_open
import torch
def check_safetensors_dtypes(safetensors_path: str):
"""
Check what dtypes are in the safetensors file.
Useful for debugging dtype issues.
"""
print(f"🔍 Checking dtypes in: {safetensors_path}")
dtype_counts = {}
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
dtype_str = str(tensor.dtype)
if dtype_str not in dtype_counts:
dtype_counts[dtype_str] = []
dtype_counts[dtype_str].append(key)
print("📊 Dtype summary:")
for dtype, keys in dtype_counts.items():
print(f" {dtype}: {len(keys)} parameters")
if dtype == "torch.bfloat16":
print(f" ⚠️ BFloat16 detected - will convert to float32")
print(f" Examples: {keys[:3]}")
return dtype_counts
def convert_tensor_dtype(tensor: torch.Tensor) -> torch.Tensor:
"""
Convert tensor to MLX-compatible dtype.
"""
if tensor.dtype == torch.bfloat16:
# Convert BFloat16 to float32
return tensor.float()
elif tensor.dtype == torch.float64:
# Convert float64 to float32 for efficiency
return tensor.float()
else:
# Keep other dtypes as-is
return tensor
def map_t5_encoder_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
"""
Map T5 encoder weights from PyTorch format to MLX format.
Following the pattern used in MLX Stable Diffusion.
Args:
key: Parameter name from PyTorch model
value: Parameter tensor
Returns:
List of (key, value) tuples for MLX model
"""
# Handle the main structural difference: FFN gate layer
if ".ffn.gate.0.weight" in key:
# PyTorch has Sequential(Linear, GELU) but MLX has separate gate_proj + gate_act
key = key.replace(".ffn.gate.0.weight", ".ffn.gate_proj.weight")
return [(key, value)]
elif ".ffn.gate.0.bias" in key:
# Handle bias if it exists
key = key.replace(".ffn.gate.0.bias", ".ffn.gate_proj.bias")
return [(key, value)]
elif ".ffn.gate.1" in key:
# Skip GELU activation parameters - MLX handles this separately
print(f"Skipping GELU parameter: {key}")
return []
# Handle any other potential FFN mappings
elif ".ffn.fc1.weight" in key:
return [(key, value)]
elif ".ffn.fc2.weight" in key:
return [(key, value)]
# Handle attention layers (should be direct mapping)
elif ".attn.q.weight" in key:
return [(key, value)]
elif ".attn.k.weight" in key:
return [(key, value)]
elif ".attn.v.weight" in key:
return [(key, value)]
elif ".attn.o.weight" in key:
return [(key, value)]
# Handle embeddings and norms (direct mapping)
elif "token_embedding.weight" in key:
return [(key, value)]
elif "pos_embedding.embedding.weight" in key:
return [(key, value)]
elif "norm1.weight" in key or "norm2.weight" in key or "norm.weight" in key:
return [(key, value)]
# Default: direct mapping for any other parameters
else:
return [(key, value)]
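# For example, a hypothetical PyTorch key "blocks.0.ffn.gate.0.weight" is remapped to
# "blocks.0.ffn.gate_proj.weight", the matching "blocks.0.ffn.gate.1.*" (GELU) entry is
# dropped, and every other key is passed through to the MLX model unchanged.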
def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
"""Flatten nested list of parameter tuples"""
return [(k, v) for p in params for (k, v) in p]
def _load_safetensor_weights(
mapper_func,
model,
weight_file: str,
float16: bool = False
):
"""
Load safetensor weights using the mapping function.
Based on MLX SD pattern.
"""
dtype = mx.float16 if float16 else mx.float32
# Load weights from safetensors file
weights = {}
with safe_open(weight_file, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
# Handle BFloat16 - convert to float32 first
if tensor.dtype == torch.bfloat16:
print(f"Converting BFloat16 to float32 for: {key}")
tensor = tensor.float() # Convert to float32
weights[key] = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping function
mapped_weights = _flatten([mapper_func(k, v) for k, v in weights.items()])
# Update model with mapped weights
model.update(tree_unflatten(mapped_weights))
return model
def load_t5_encoder_from_safetensors(
safetensors_path: str,
model, # Your MLX T5Encoder instance
float16: bool = False
):
"""
Load T5 encoder weights from safetensors file into MLX model.
Args:
safetensors_path: Path to the safetensors file
model: Your MLX T5Encoder model instance
float16: Whether to use float16 precision
Returns:
Model with loaded weights
"""
print(f"Loading T5 encoder weights from: {safetensors_path}")
# Load and map weights
model = _load_safetensor_weights(
map_t5_encoder_weights,
model,
safetensors_path,
float16
)
print("T5 encoder weights loaded successfully!")
return model
def debug_weight_mapping(safetensors_path: str, float16: bool = False):
"""
Debug function to see how weights are being mapped.
Useful for troubleshooting conversion issues.
"""
dtype = mx.float16 if float16 else mx.float32
print("=== T5 Weight Mapping Debug ===")
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
# Handle BFloat16
original_dtype = tensor.dtype
if tensor.dtype == torch.bfloat16:
print(f"Converting BFloat16 to float32 for: {key}")
tensor = tensor.float()
value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_t5_encoder_weights(key, value)
if len(mapped) == 0:
print(f"SKIPPED: {key} ({original_dtype}) -> (no mapping)")
elif len(mapped) == 1:
new_key, new_value = mapped[0]
if new_key == key:
print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]")
else:
print(f"MAPPED: {key} ({original_dtype}) -> {new_key} [{tensor.shape}]")
else:
print(f"SPLIT: {key} ({original_dtype}) -> {len(mapped)} parameters")
for new_key, new_value in mapped:
print(f" -> {new_key} [{new_value.shape}]")
def convert_safetensors_to_mlx_weights(
safetensors_path: str,
output_path: str,
float16: bool = False
):
"""
Convert safetensors file to MLX weights file (.npz format).
Args:
safetensors_path: Input safetensors file
output_path: Output MLX weights file (.npz)
float16: Whether to use float16 precision
"""
dtype = mx.float16 if float16 else mx.float32
print(f"Converting safetensors to MLX format...")
print(f"Input: {safetensors_path}")
print(f"Output: {output_path}")
print(f"Target dtype: {dtype}")
# Load and convert weights
weights = {}
bfloat16_count = 0
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
            # Handle BFloat16 (numpy cannot represent it): convert to float32 first
            if tensor.dtype == torch.bfloat16:
                bfloat16_count += 1
                tensor = tensor.float()
            value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_t5_encoder_weights(key, value)
for new_key, new_value in mapped:
weights[new_key] = new_value
if bfloat16_count > 0:
print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to float32")
# Save as MLX format
print(f"Saving {len(weights)} parameters to: {output_path}")
mx.save_safetensors(output_path, weights)
return weights
# Example usage functions
def example_usage():
"""Example of how to use the converter with BFloat16 handling"""
safetensors_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors"
# Step 1: Check dtypes first
print("=== Step 1: Check dtypes ===")
check_safetensors_dtypes(safetensors_file)
# Step 2: Debug the mapping
print("\n=== Step 2: Debug weight mapping ===")
debug_weight_mapping(safetensors_file)
# Step 3: Convert to MLX weights file
print("\n=== Step 3: Convert to MLX ===")
    weights = convert_safetensors_to_mlx_weights(
        safetensors_file,
        "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16_mlx.safetensors",
        float16=False  # Use float32 to avoid precision loss from BFloat16
    )
# Step 4: Load into MLX model (example)
print("\n=== Step 4: Load into MLX model ===")
# model = T5Encoder # Your MLX model
# model = load_t5_encoder_from_safetensors(
# safetensors_file,
# model,
# float16=False
# )
return weights
if __name__ == "__main__":
# Run debug to see mappings
# debug_weight_mapping("Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors")
example_usage()
# Or convert weights
    # convert_safetensors_to_mlx_weights("your_model.safetensors", "your_model_mlx.safetensors")
print("T5 converter ready!")

View File

@ -0,0 +1,278 @@
import os
import torch
from safetensors.torch import save_file
from pathlib import Path
import json
from wan.modules.t5_mlx import T5Model
def convert_pickle_to_safetensors(
pickle_path: str,
safetensors_path: str,
model_class=None,
model_kwargs=None,
load_method: str = "weights_only" # Changed default to weights_only
):
"""Convert PyTorch pickle file to safetensors format."""
print(f"Loading PyTorch weights from: {pickle_path}")
# Try multiple loading methods in order of preference
methods_to_try = [load_method, "weights_only", "state_dict", "full_model"]
for method in methods_to_try:
try:
if method == "weights_only":
state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
elif method == "state_dict":
checkpoint = torch.load(pickle_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, dict) and 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
elif method == "full_model":
model = torch.load(pickle_path, map_location='cpu')
if hasattr(model, 'state_dict'):
state_dict = model.state_dict()
else:
state_dict = model
print(f"✅ Successfully loaded with method: {method}")
break
except Exception as e:
print(f"❌ Method {method} failed: {e}")
continue
else:
raise RuntimeError(f"All loading methods failed for {pickle_path}")
# Clean up the state dict
state_dict = clean_state_dict(state_dict)
print(f"Found {len(state_dict)} parameters")
# Convert BF16 to FP32 if needed
for key, tensor in state_dict.items():
if tensor.dtype == torch.bfloat16:
state_dict[key] = tensor.to(torch.float32)
print(f"Converted {key} from bfloat16 to float32")
# Save as safetensors
print(f"Saving to safetensors: {safetensors_path}")
os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
save_file(state_dict, safetensors_path)
print("✅ T5 conversion complete!")
return state_dict
def clean_state_dict(state_dict):
"""
Clean up state dict by removing unwanted prefixes or keys.
"""
cleaned = {}
for key, value in state_dict.items():
# Remove common prefixes that might interfere
clean_key = key
if clean_key.startswith('module.'):
clean_key = clean_key[7:]
if clean_key.startswith('model.'):
clean_key = clean_key[6:]
cleaned[clean_key] = value
if clean_key != key:
print(f"Cleaned key: {key} -> {clean_key}")
return cleaned
def load_with_your_torch_model(pickle_path: str, model_class, **model_kwargs):
"""
Load pickle weights into your specific PyTorch T5 model implementation.
Args:
pickle_path: Path to pickle file
model_class: Your T5Encoder class
**model_kwargs: Arguments for your model constructor
"""
print("Method 1: Loading into your PyTorch T5 model")
# Initialize your model
model = model_class(**model_kwargs)
# Load checkpoint
checkpoint = torch.load(pickle_path, map_location='cpu')
# Handle different checkpoint formats
if isinstance(checkpoint, dict):
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
# Assume the dict IS the state dict
state_dict = checkpoint
else:
# Assume it's a model object
state_dict = checkpoint.state_dict()
# Clean and load
state_dict = clean_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore missing keys
return model, state_dict
def explore_pickle_file(pickle_path: str):
"""
Explore the contents of a pickle file to understand its structure.
"""
print(f"🔍 Exploring pickle file: {pickle_path}")
try:
# Try loading with weights_only first (safer)
print("\n--- Trying weights_only=True ---")
try:
data = torch.load(pickle_path, map_location='cpu', weights_only=True)
print(f"✅ Loaded with weights_only=True")
print(f"Type: {type(data)}")
if isinstance(data, dict):
print(f"Dictionary with {len(data)} keys:")
for i, key in enumerate(data.keys()):
print(f" {key}: {type(data[key])}")
if hasattr(data[key], 'shape'):
print(f" Shape: {data[key].shape}")
if i >= 9: # Show first 10 keys
break
except Exception as e:
print(f"❌ weights_only=True failed: {e}")
# Try regular loading
print("\n--- Trying regular loading ---")
data = torch.load(pickle_path, map_location='cpu')
print(f"✅ Loaded successfully")
print(f"Type: {type(data)}")
if hasattr(data, 'state_dict'):
print("📋 Found state_dict method")
state_dict = data.state_dict()
print(f"State dict has {len(state_dict)} parameters")
elif isinstance(data, dict):
print(f"📋 Dictionary with keys: {list(data.keys())}")
# Check for common checkpoint keys
if 'state_dict' in data:
print("Found 'state_dict' key")
print(f"state_dict has {len(data['state_dict'])} parameters")
elif 'model' in data:
print("Found 'model' key")
print(f"model has {len(data['model'])} parameters")
except Exception as e:
print(f"❌ Failed to load: {e}")
def full_conversion_pipeline(
pickle_path: str,
safetensors_path: str,
torch_model_class=None,
model_kwargs=None
):
"""
Complete pipeline: pickle -> safetensors -> ready for MLX conversion
"""
print("🚀 Starting full conversion pipeline")
print("="*50)
# Step 1: Explore the pickle file
print("Step 1: Exploring pickle file structure")
explore_pickle_file(pickle_path)
# Step 2: Convert to safetensors
print(f"\nStep 2: Converting to safetensors")
# Try different loading methods
for method in ["weights_only", "state_dict", "full_model"]:
try:
print(f"\nTrying load method: {method}")
state_dict = convert_pickle_to_safetensors(
pickle_path,
safetensors_path,
model_class=torch_model_class,
model_kwargs=model_kwargs,
load_method=method
)
print(f"✅ Success with method: {method}")
break
except Exception as e:
print(f"❌ Method {method} failed: {e}")
continue
else:
print("❌ All methods failed!")
return None
print(f"\n🎉 Conversion complete!")
print(f"Safetensors file saved to: {safetensors_path}")
print(f"Ready for MLX conversion using the previous script!")
return state_dict
# Example usage
def example_usage():
"""
Example of how to use the conversion functions
"""
# Your model class and parameters
# class YourT5Encoder(nn.Module):
# def __init__(self, vocab_size, d_model, ...):
# ...
pickle_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"
safetensors_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors"
# Method 1: Quick exploration
print("=== Exploring pickle file ===")
explore_pickle_file(pickle_file)
# Method 2: Full pipeline
print("\n=== Full conversion pipeline ===")
state_dict = full_conversion_pipeline(
pickle_file,
safetensors_file,
torch_model_class=T5Model, # Your model class
model_kwargs={
'vocab_size': 256384,
'd_model': 4096,
'num_layers': 24,
# ... other parameters
}
)
# Method 3: Direct conversion (if you know the format)
print("\n=== Direct conversion ===")
# state_dict = convert_pickle_to_safetensors(
# pickle_file,
# safetensors_file,
# load_method="state_dict" # or "weights_only" or "full_model"
# )
if __name__ == "__main__":
example_usage()

View File

@ -0,0 +1,312 @@
import glob
import gc
import logging
import math
import os
import random
import sys
from tqdm import tqdm
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .modules.model_mlx import WanModel
from .modules.t5_mlx import T5EncoderModel
from .modules.vae_mlx import WanVAE
from .utils.fm_solvers_mlx import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
from .utils.fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
from .wan_model_io import load_wan_from_safetensors
class WanT2V:
def __init__(
self,
config,
checkpoint_dir,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
"""
self.config = config
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = mx.float16 if config.param_dtype == 'float16' else mx.float32
# Initialize T5 text encoder - with automatic conversion
t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint)
mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors')
if not os.path.exists(mlx_t5_path):
# Check if it's a .pth file that needs conversion
pth_path = t5_checkpoint_path.replace('.safetensors', '.pth')
if os.path.exists(pth_path):
logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}")
from .t5_torch_to_sf import convert_pickle_to_safetensors
convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only")
# Convert torch safetensors to MLX safetensors
from .t5_model_io import convert_safetensors_to_mlx_weights
convert_safetensors_to_mlx_weights(t5_checkpoint_path, mlx_t5_path, float16=(self.param_dtype == mx.float16))
else:
raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}")
t5_checkpoint_path = mlx_t5_path # Use the MLX version
logging.info(f"Loading T5 text encoder... from {t5_checkpoint_path}")
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
checkpoint_path=t5_checkpoint_path,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer))
# Initialize VAE
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
# Initialize VAE - with automatic conversion
vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint)
if not os.path.exists(vae_path):
# Check for PyTorch VAE file to convert
pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth')
if not os.path.exists(pth_vae_path):
# Try alternative naming
pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth')
if os.path.exists(pth_vae_path):
logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}")
from .vae_model_io import convert_pytorch_to_mlx
convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16))
else:
raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}")
logging.info("Loading VAE...")
self.vae = WanVAE(vae_pth=vae_path)
# Initialize WanModel
logging.info(f"Creating WanModel from {checkpoint_dir}")
# Create model with config parameters
self.model = WanModel(
model_type='t2v',
patch_size=config.patch_size,
text_len=config.text_len,
in_dim=16,
dim=config.dim,
ffn_dim=config.ffn_dim,
freq_dim=config.freq_dim,
text_dim=4096,
out_dim=16,
num_heads=config.num_heads,
num_layers=config.num_layers,
window_size=getattr(config, 'window_size', (-1, -1)),
qk_norm=getattr(config, 'qk_norm', True),
cross_attn_norm=getattr(config, 'cross_attn_norm', True),
eps=getattr(config, 'eps', 1e-6)
)
# Load pretrained weights - with automatic conversion
model_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model_mlx.safetensors")
if not os.path.exists(model_path):
# Check for directory with multiple files (14B model)
pattern = os.path.join(checkpoint_dir, "diffusion_mlx_model*.safetensors")
mlx_files = glob.glob(pattern)
if not mlx_files:
# No MLX files found, look for PyTorch files to convert
pytorch_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors")
pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
pytorch_files = glob.glob(pytorch_pattern)
if os.path.exists(pytorch_path):
logging.info(f"Converting single PyTorch model to MLX: {pytorch_path}")
from .wan_model_io import convert_safetensors_to_mlx_weights
convert_safetensors_to_mlx_weights(
pytorch_path,
model_path,
float16=(self.param_dtype == mx.float16)
)
elif pytorch_files:
logging.info(f"Converting {len(pytorch_files)} PyTorch model files to MLX")
from .wan_model_io import convert_multiple_safetensors_to_mlx
convert_multiple_safetensors_to_mlx(
checkpoint_dir,
float16=(self.param_dtype == mx.float16)
)
else:
raise FileNotFoundError(f"No PyTorch model files found in {checkpoint_dir}")
# Load the model (now MLX format exists)
if os.path.exists(model_path):
# Single file (1.3B)
logging.info(f"Loading model weights from {model_path}")
self.model = load_wan_from_safetensors(model_path, self.model, float16=(self.param_dtype == mx.float16))
else:
# Multiple files (14B)
logging.info(f"Loading model weights from directory {checkpoint_dir}")
self.model = load_wan_from_safetensors(checkpoint_dir, self.model, float16=(self.param_dtype == mx.float16))
# Set model to eval mode
self.model.eval()
self.sp_size = 1 # No sequence parallelism in MLX version
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save memory
Returns:
mx.array:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames
- H: Frame height (from size)
- W: Frame width (from size)
"""
# Preprocess
F = frame_num
target_shape = (
self.vae.model.z_dim,
(F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2]
)
seq_len = math.ceil(
(target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size
) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# Set random seed
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
mx.random.seed(seed)
# Encode text prompts
logging.info("Encoding text prompts...")
context = self.text_encoder([input_prompt])
context_null = self.text_encoder([n_prompt])
# Generate initial noise
noise = [
mx.random.normal(
shape=target_shape,
dtype=mx.float32
)
]
# Initialize scheduler
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(
sampling_steps, shift=shift
)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False
)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
sigmas=sampling_sigmas
)
else:
raise NotImplementedError(f"Unsupported solver: {sample_solver}")
# Sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
logging.info(f"Generating video with {len(timesteps)} steps...")
        for t in tqdm(timesteps):
latent_model_input = latents
timestep = mx.array([t])
# Model predictions
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c
)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null
)[0]
# Classifier-free guidance
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond
)
# Scheduler step
temp_x0 = sample_scheduler.step(
mx.expand_dims(noise_pred, 0),
t,
mx.expand_dims(latents[0], 0),
return_dict=False
)[0]
latents = [mx.squeeze(temp_x0, 0)]
mx.eval(latents)
x0 = latents
# Decode latents to video
logging.info("Decoding latents to video...")
videos = self.vae.decode(x0)
# Memory cleanup
del noise, latents, sample_scheduler
if offload_model:
mx.eval(videos) # Ensure computation is complete
gc.collect()
return videos[0]
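# Minimal usage sketch (illustrative, not part of the original file). The config import
# below is an assumption about how the Wan configs are exposed; the checkpoint directory
# is whatever local folder holds the downloaded 1.3B T2V weights.
#
#   from wan.configs import t2v_1_3B                      # assumed config object
#   pipe = WanT2V(config=t2v_1_3B, checkpoint_dir="Wan2.1-T2V-1.3B")
#   video = pipe.generate(
#       "a corgi surfing a wave at sunset",
#       size=(832, 480), frame_num=81,
#       sample_solver="unipc", sampling_steps=50, guide_scale=6.0, seed=42)
#   # video: mx.array of shape (3, 81, 480, 832) with values roughly in [-1, 1]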

View File

@ -0,0 +1,8 @@
from .fm_solvers_mlx import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
__all__ = [
    'get_sampling_sigmas', 'retrieve_timesteps',
    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
]

View File

@ -0,0 +1,562 @@
import math
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
def get_sampling_sigmas(sampling_steps, shift):
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
sigma = (shift * sigma / (1 + (shift - 1) * sigma))
return sigma
def retrieve_timesteps(
scheduler,
num_inference_steps=None,
device=None,
timesteps=None,
sigmas=None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None:
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
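# For example, get_sampling_sigmas(4, shift=5.0) starts from np.linspace(1, 0, 5)[:4]
# = [1.0, 0.75, 0.5, 0.25] and shifts it to [1.0, 0.9375, 0.8333..., 0.625], i.e. a
# larger shift keeps the schedule at high noise levels for longer.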
class SchedulerOutput:
"""Output class for scheduler step results."""
def __init__(self, prev_sample: mx.array):
self.prev_sample = prev_sample
class FlowDPMSolverMultistepScheduler:
"""
MLX implementation of FlowDPMSolverMultistepScheduler.
A fast dedicated high-order solver for diffusion ODEs.
"""
order = 1
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting: bool = False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero",
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
invert_sigmas: bool = False,
):
# Store configuration
self.config = {
'num_train_timesteps': num_train_timesteps,
'solver_order': solver_order,
'prediction_type': prediction_type,
'shift': shift,
'use_dynamic_shifting': use_dynamic_shifting,
'thresholding': thresholding,
'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
'sample_max_value': sample_max_value,
'algorithm_type': algorithm_type,
'solver_type': solver_type,
'lower_order_final': lower_order_final,
'euler_at_final': euler_at_final,
'final_sigmas_type': final_sigmas_type,
'lambda_min_clipped': lambda_min_clipped,
'variance_type': variance_type,
'invert_sigmas': invert_sigmas,
}
# Validate algorithm type
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.config['algorithm_type'] = "dpmsolver++"
else:
raise NotImplementedError(f"{algorithm_type} is not implemented")
# Validate solver type
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.config['solver_type'] = "midpoint"
else:
raise NotImplementedError(f"{solver_type} is not implemented")
# Initialize scheduling
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = mx.array(sigmas, dtype=mx.float32)
if not use_dynamic_shifting:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigma_min = float(self.sigmas[-1])
self.sigma_max = float(self.sigmas[0])
@property
def step_index(self):
return self._step_index
@property
def begin_index(self):
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, None] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""Sets the discrete timesteps used for the diffusion chain."""
if self.config['use_dynamic_shifting'] and mu is None:
raise ValueError(
"you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
if self.config['use_dynamic_shifting']:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
if shift is None:
shift = self.config['shift']
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
if self.config['final_sigmas_type'] == "sigma_min":
sigma_last = self.sigma_min
elif self.config['final_sigmas_type'] == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
)
timesteps = sigmas * self.config['num_train_timesteps']
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(timesteps, dtype=mx.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [None] * self.config['solver_order']
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
def _threshold_sample(self, sample: mx.array) -> mx.array:
"""Dynamic thresholding method."""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
# Flatten sample for quantile calculation
sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = mx.abs(sample_flat)
# Compute quantile
s = mx.quantile(
abs_sample,
self.config['dynamic_thresholding_ratio'],
axis=1,
keepdims=True
)
s = mx.clip(s, 1, self.config['sample_max_value'])
# Threshold and normalize
sample_flat = mx.clip(sample_flat, -s, s) / s
sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
return sample.astype(dtype)
def _sigma_to_t(self, sigma):
return sigma * self.config['num_train_timesteps']
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
def time_shift(self, mu: float, sigma: float, t: mx.array):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
def convert_model_output(
self,
model_output: mx.array,
sample: mx.array,
**kwargs,
) -> mx.array:
"""Convert model output to the corresponding type the algorithm needs."""
# DPM-Solver++ needs to solve an integral of the data prediction model
if self.config['algorithm_type'] in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be "
f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
)
if self.config['thresholding']:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model
elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be "
f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
)
if self.config['thresholding']:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def dpm_solver_first_order_update(
self,
model_output: mx.array,
sample: mx.array,
noise: Optional[mx.array] = None,
**kwargs,
) -> mx.array:
"""One step for the first-order DPMSolver (equivalent to DDIM)."""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s = mx.log(alpha_s) - mx.log(sigma_s)
h = lambda_t - lambda_s
if self.config['algorithm_type'] == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
elif self.config['algorithm_type'] == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mx.exp(h) - 1.0)) * model_output
elif self.config['algorithm_type'] == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * mx.exp(-h)) * sample +
(alpha_t * (1 - mx.exp(-2.0 * h))) * model_output +
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
)
elif self.config['algorithm_type'] == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample -
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
)
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[mx.array],
sample: mx.array,
noise: Optional[mx.array] = None,
**kwargs,
) -> mx.array:
"""One step for the second-order multistep DPMSolver."""
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config['algorithm_type'] == "dpmsolver++":
if self.config['solver_type'] == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample -
(alpha_t * (mx.exp(-h) - 1.0)) * D0 -
0.5 * (alpha_t * (mx.exp(-h) - 1.0)) * D1
)
elif self.config['solver_type'] == "heun":
x_t = (
(sigma_t / sigma_s0) * sample -
(alpha_t * (mx.exp(-h) - 1.0)) * D0 +
(alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config['algorithm_type'] == "dpmsolver":
if self.config['solver_type'] == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample -
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
0.5 * (sigma_t * (mx.exp(h) - 1.0)) * D1
)
elif self.config['solver_type'] == "heun":
x_t = (
(alpha_t / alpha_s0) * sample -
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
(sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config['algorithm_type'] == "sde-dpmsolver++":
assert noise is not None
if self.config['solver_type'] == "midpoint":
x_t = (
(sigma_t / sigma_s0 * mx.exp(-h)) * sample +
(alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
0.5 * (alpha_t * (1 - mx.exp(-2.0 * h))) * D1 +
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
)
elif self.config['solver_type'] == "heun":
x_t = (
(sigma_t / sigma_s0 * mx.exp(-h)) * sample +
(alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
(alpha_t * ((1.0 - mx.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 +
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
)
elif self.config['algorithm_type'] == "sde-dpmsolver":
assert noise is not None
if self.config['solver_type'] == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample -
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
(sigma_t * (mx.exp(h) - 1.0)) * D1 +
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
)
elif self.config['solver_type'] == "heun":
x_t = (
(alpha_t / alpha_s0) * sample -
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
)
return x_t
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[mx.array],
sample: mx.array,
**kwargs,
) -> mx.array:
"""One step for the third-order multistep DPMSolver."""
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config['algorithm_type'] == "dpmsolver++":
x_t = (
(sigma_t / sigma_s0) * sample -
(alpha_t * (mx.exp(-h) - 1.0)) * D0 +
(alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1 -
(alpha_t * ((mx.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config['algorithm_type'] == "dpmsolver":
x_t = (
(alpha_t / alpha_s0) * sample -
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
(sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 -
(sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
return x_t
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        # mx.where expects (condition, x, y), so the index lookup is done in numpy
        indices = np.nonzero(np.array(schedule_timesteps) == np.array(timestep))[0]
        pos = 1 if len(indices) > 1 else 0
        return int(indices[pos])
def _init_step_index(self, timestep):
"""Initialize the step_index counter for the scheduler."""
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: mx.array,
timestep: Union[int, mx.array],
sample: mx.array,
generator=None,
variance_noise: Optional[mx.array] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""Predict the sample from the previous timestep."""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# Improve numerical stability for small number of steps
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and
(self.config['euler_at_final'] or
(self.config['lower_order_final'] and len(self.timesteps) < 15) or
self.config['final_sigmas_type'] == "zero")
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and
self.config['lower_order_final'] and
len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config['solver_order'] - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
# Upcast to avoid precision issues
sample = sample.astype(mx.float32)
# Generate noise if needed for SDE variants
if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = mx.random.normal(model_output.shape, dtype=mx.float32)
elif self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.astype(mx.float32)
else:
noise = None
if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(
model_output, sample=sample, noise=noise
)
elif self.config['solver_order'] == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, sample=sample, noise=noise
)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, sample=sample
)
if self.lower_order_nums < self.config['solver_order']:
self.lower_order_nums += 1
# Cast sample back to expected dtype
prev_sample = prev_sample.astype(model_output.dtype)
# Increase step index
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
"""Scale model input - no scaling needed for this scheduler."""
return sample
def add_noise(
self,
original_samples: mx.array,
noise: mx.array,
timesteps: mx.array,
) -> mx.array:
"""Add noise to original samples."""
sigmas = self.sigmas.astype(original_samples.dtype)
schedule_timesteps = self.timesteps
# Get step indices
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
step_indices = [self.step_index] * timesteps.shape[0]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices]
while len(sigma.shape) < len(original_samples.shape):
sigma = mx.expand_dims(sigma, -1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config['num_train_timesteps']
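# Minimal usage sketch (illustrative, not part of the original file): drive the solver
# with a dummy latent and a zero "model" so the control flow can be exercised without
# loading any weights. Shapes are arbitrary placeholders.
if __name__ == "__main__":
    scheduler = FlowDPMSolverMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
    sampling_sigmas = get_sampling_sigmas(sampling_steps=10, shift=5.0)
    timesteps, _ = retrieve_timesteps(scheduler, sigmas=sampling_sigmas)
    sample = mx.random.normal((1, 16, 4, 8, 8))   # dummy latent
    for t in timesteps:
        model_output = mx.zeros_like(sample)      # stand-in for the diffusion model's prediction
        sample = scheduler.step(model_output, t, sample, return_dict=False)[0]
    print(sample.shape)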

View File

@ -0,0 +1,546 @@
import math
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
class SchedulerOutput:
"""Output class for scheduler step results."""
def __init__(self, prev_sample: mx.array):
self.prev_sample = prev_sample
class FlowUniPCMultistepScheduler:
"""
MLX implementation of UniPCMultistepScheduler.
A training-free framework designed for the fast sampling of diffusion models.
"""
order = 1
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting: bool = False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero",
):
# Store configuration
self.config = {
'num_train_timesteps': num_train_timesteps,
'solver_order': solver_order,
'prediction_type': prediction_type,
'shift': shift,
'use_dynamic_shifting': use_dynamic_shifting,
'thresholding': thresholding,
'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
'sample_max_value': sample_max_value,
'predict_x0': predict_x0,
'solver_type': solver_type,
'lower_order_final': lower_order_final,
'disable_corrector': disable_corrector,
'solver_p': solver_p,
'timestep_spacing': timestep_spacing,
'steps_offset': steps_offset,
'final_sigmas_type': final_sigmas_type,
}
# Validate solver type
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.config['solver_type'] = "bh2"
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}"
)
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = mx.array(sigmas, dtype=mx.float32)
if not use_dynamic_shifting:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigma_min = float(self.sigmas[-1])
self.sigma_max = float(self.sigmas[0])
@property
def step_index(self):
"""The index counter for current timestep."""
return self._step_index
@property
def begin_index(self):
"""The index for the first timestep."""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""Sets the begin index for the scheduler."""
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, None] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""Sets the discrete timesteps used for the diffusion chain."""
if self.config['use_dynamic_shifting'] and mu is None:
raise ValueError(
"you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
if self.config['use_dynamic_shifting']:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
if shift is None:
shift = self.config['shift']
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
if self.config['final_sigmas_type'] == "sigma_min":
sigma_last = self.sigma_min
elif self.config['final_sigmas_type'] == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
)
timesteps = sigmas * self.config['num_train_timesteps']
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(timesteps, dtype=mx.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [None] * self.config['solver_order']
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers
self._step_index = None
self._begin_index = None
def _threshold_sample(self, sample: mx.array) -> mx.array:
"""Dynamic thresholding method."""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
# Flatten sample for quantile calculation
sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = mx.abs(sample_flat)
# Compute quantile
s = mx.quantile(
abs_sample,
self.config['dynamic_thresholding_ratio'],
axis=1,
keepdims=True
)
s = mx.clip(s, 1, self.config['sample_max_value'])
# Threshold and normalize
sample_flat = mx.clip(sample_flat, -s, s) / s
sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
return sample.astype(dtype)
def _sigma_to_t(self, sigma):
return sigma * self.config['num_train_timesteps']
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
def time_shift(self, mu: float, sigma: float, t: mx.array):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
def convert_model_output(
self,
model_output: mx.array,
sample: mx.array = None,
**kwargs,
) -> mx.array:
"""Convert the model output to the corresponding type the UniPC algorithm needs."""
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
f"for the UniPCMultistepScheduler."
)
if self.config['thresholding']:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config['prediction_type'] == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
f"for the UniPCMultistepScheduler."
)
if self.config['thresholding']:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def multistep_uni_p_bh_update(
self,
model_output: mx.array,
sample: mx.array = None,
order: int = None,
**kwargs,
) -> mx.array:
"""One step for the UniP (B(h) version)."""
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
h = lambda_t - lambda_s0
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = mx.array(rks)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = mx.exp(hh) - 1 # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config['solver_type'] == "bh1":
B_h = hh
elif self.config['solver_type'] == "bh2":
B_h = mx.exp(hh) - 1
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(mx.power(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = mx.stack(R)
b = mx.array(b)
if len(D1s) > 0:
D1s = mx.stack(D1s, axis=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = mx.array([0.5], dtype=x.dtype)
else:
rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1], stream=mx.cpu).astype(x.dtype)
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.astype(x.dtype)
return x_t
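# Background note (added for clarity): lambda = log(alpha) - log(sigma) is half the
# log-SNR and h = lambda_t - lambda_s0 is the step size in that domain. The predictor
# coefficients rhos_p come from a small linear system built from powers of the ratios
# rks, so the stored differences D1s extrapolate the trajectory one step forward.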
def multistep_uni_c_bh_update(
self,
this_model_output: mx.array,
last_sample: mx.array = None,
this_sample: mx.array = None,
order: int = None,
**kwargs,
) -> mx.array:
"""One step for the UniC (B(h) version)."""
model_output_list = self.model_outputs
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
h = lambda_t - lambda_s0
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1)
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = mx.array(rks)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = mx.exp(hh) - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config['solver_type'] == "bh1":
B_h = hh
elif self.config['solver_type'] == "bh2":
B_h = mx.exp(hh) - 1
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(mx.power(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = mx.stack(R)
b = mx.array(b)
if len(D1s) > 0:
D1s = mx.stack(D1s, axis=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = mx.array([0.5], dtype=x.dtype)
else:
rhos_c = mx.linalg.solve(R, b, stream=mx.cpu).astype(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.astype(x.dtype)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
condition = schedule_timesteps == timestep
indices = mx.argmax(condition.astype(mx.int32))
# Convert scalar to int and return
return int(indices)
def _init_step_index(self, timestep):
"""Initialize the step_index counter for the scheduler."""
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: mx.array,
timestep: Union[int, mx.array],
sample: mx.array,
return_dict: bool = True,
generator=None
) -> Union[SchedulerOutput, Tuple]:
"""Predict the sample from the previous timestep."""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
use_corrector = (
self.step_index > 0 and
self.step_index - 1 not in self.disable_corrector and
self.last_sample is not None
)
model_output_convert = self.convert_model_output(
model_output, sample=sample
)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
for i in range(self.config['solver_order'] - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
if self.config['lower_order_final']:
this_order = min(
self.config['solver_order'],
len(self.timesteps) - self.step_index
)
else:
this_order = self.config['solver_order']
self.this_order = min(this_order, self.lower_order_nums + 1)
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output,
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config['solver_order']:
self.lower_order_nums += 1
# Increase step index
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
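# Minimal usage sketch (illustrative, not part of the original file), assuming a
# `model` callable that returns a flow prediction for an mx.array of latents:
#
#     scheduler.set_timesteps(num_inference_steps=50, shift=5.0)
#     for t in scheduler.timesteps:
#         noise_pred = model(latents, t)
#         latents = scheduler.step(noise_pred, t, latents).prev_sample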
def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
"""Scale model input - no scaling needed for this scheduler."""
return sample
def add_noise(
self,
original_samples: mx.array,
noise: mx.array,
timesteps: mx.array,
) -> mx.array:
"""Add noise to original samples."""
sigmas = self.sigmas.astype(original_samples.dtype)
schedule_timesteps = self.timesteps
# Get step indices
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
step_indices = [self.step_index] * timesteps.shape[0]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices]
while len(sigma.shape) < len(original_samples.shape):
sigma = mx.expand_dims(sigma, -1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config['num_train_timesteps']

View File

@ -0,0 +1,373 @@
# Copied from https://github.com/kq-chen/qwen-vl-utils
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from __future__ import annotations
import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
logger = logging.getLogger(__name__)
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
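# Worked example (added for clarity): a 1080x1920 frame with the default factor of 28
# rounds to (1092, 1932); both sides are divisible by 28 and 1092 * 1932 = 2,109,744
# pixels, which lies inside [MIN_PIXELS, MAX_PIXELS], so no further rescaling happens.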
def fetch_image(ele: dict[str, str | Image.Image],
size_factor: int = IMAGE_FACTOR) -> Image.Image:
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
image_obj = Image.open(requests.get(image, stream=True).raw)
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, supported inputs are a local path, http url, base64 data or PIL.Image, got {image}"
)
image = image_obj.convert("RGB")
## resize
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=size_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", MIN_PIXELS)
max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
"""calculate the number of frames for video used for model inputs.
Args:
ele (dict): a dict contains the configuration of video.
support either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps to extract frames for model inputs.
- min_frames: the minimum number of frames of the video, only used when fps is provided.
- max_frames: the maximum number of frames of the video, only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.
Raises:
ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
Returns:
int: the number of frames for video used for model inputs.
"""
assert not ("fps" in ele and
"nframes" in ele), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(
ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(
ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
FRAME_FACTOR)
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
raise ValueError(
f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
)
return nframes
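# Worked example (added for clarity): a 10 s clip with total_frames=300 and
# video_fps=30 sampled at the default FPS of 2.0 gives 300 / 30 * 2 = 20 frames,
# which already satisfies the FRAME_FACTOR=2 rounding and the min/max frame clamps.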
def _read_video_torchvision(ele: dict,) -> torch.Tensor:
"""read video using torchvision.io.read_video
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn(
"torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0."
)
if "file://" in video_path:
video_path = video_path[7:]
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(
f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
video = video[idx]
return video
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
def _read_video_decord(ele: dict,) -> torch.Tensor:
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
# TODO: support start_pts and end_pts
if 'video_start' in ele or 'video_end' in ele:
raise NotImplementedError(
"video_start and video_end are not supported by the decord backend yet.")
total_frames, video_fps = len(vr), vr.get_avg_fps()
logger.info(
f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
return video
VIDEO_READER_BACKENDS = {
"decord": _read_video_decord,
"torchvision": _read_video_torchvision,
}
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
print(
f"qwen-vl-utils using {video_reader_backend} to read video.",
file=sys.stderr)
return video_reader_backend
def fetch_video(
ele: dict,
image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
# Move tensors that live on an accelerator device (e.g. MPS) to the CPU before processing
original_device = None
if isinstance(ele.get("video"), torch.Tensor) and ele["video"].device.type != "cpu":
original_device = ele["video"].device
ele["video"] = ele["video"].cpu()
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(
min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
int(min_pixels * 1.05))
max_pixels = ele.get("max_pixels", max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
return video
else:
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
images = [
fetch_image({
"image": video_element,
**process_info
},
size_factor=image_factor)
for video_element in ele["video"]
]
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
if len(images) < nframes:
images.extend([images[-1]] * (nframes - len(images)))
return images
def extract_vision_info(
conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if ("image" in ele or "image_url" in ele or
"video" in ele or
ele["type"] in ("image", "image_url", "video")):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
conversations: list[dict] | list[list[dict]],
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
None]:
vision_infos = extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
elif "video" in vision_info:
video_inputs.append(fetch_video(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
return image_inputs, video_inputs
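# Usage sketch (illustrative, not from the original file; the message layout follows
# the usual qwen-vl-utils convention and the file path is hypothetical):
#
#     messages = [{"role": "user", "content": [
#         {"type": "video", "video": "file:///path/to/clip.mp4", "fps": 2.0},
#         {"type": "text", "text": "Describe this video."},
#     ]}]
#     image_inputs, video_inputs = process_vision_info(messages)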

View File

@ -0,0 +1,175 @@
import argparse
import binascii
import os
import os.path as osp
import imageio
import mlx.core as mx
import numpy as np
__all__ = ['cache_video', 'cache_image', 'str2bool']
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:
if not suffix.startswith('.'):
suffix = '.' + suffix
name += suffix
return name
def make_grid(tensor, nrow=8, normalize=True, value_range=(-1, 1)):
"""MLX equivalent of torchvision.utils.make_grid"""
# tensor shape: (batch, channels, height, width)
batch_size, channels, height, width = tensor.shape
# Calculate grid dimensions
ncol = nrow
nrow_actual = (batch_size + ncol - 1) // ncol
# Create grid
grid_height = height * nrow_actual + (nrow_actual - 1) * 2 # 2 pixel padding
grid_width = width * ncol + (ncol - 1) * 2
# Initialize grid with zeros
grid = mx.zeros((channels, grid_height, grid_width))
# Fill grid
for idx in range(batch_size):
row = idx // ncol
col = idx % ncol
y_start = row * (height + 2)
y_end = y_start + height
x_start = col * (width + 2)
x_end = x_start + width
img = tensor[idx]
if normalize:
# Normalize to [0, 1]
img = (img - value_range[0]) / (value_range[1] - value_range[0])
grid[:, y_start:y_end, x_start:x_end] = img
return grid
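# Shape note (added for clarity): with batch_size=4 and nrow=2 the grid has 2 columns
# and 2 rows, so the output shape is (channels, 2 * height + 2, 2 * width + 2) once the
# fixed 2-pixel padding between cells is included.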
def cache_video(tensor,
save_file=None,
fps=30,
suffix='.mp4',
nrow=8,
normalize=True,
value_range=(-1, 1),
retry=5):
# cache file
cache_file = osp.join('/tmp', rand_name(
suffix=suffix)) if save_file is None else save_file
# save to cache
error = None
for _ in range(retry):
try:
# preprocess
tensor = mx.clip(tensor, value_range[0], value_range[1])
# tensor shape: (batch, channels, frames, height, width)
# Process each frame
frames = []
for frame_idx in range(tensor.shape[2]):
frame = tensor[:, :, frame_idx, :, :] # (batch, channels, height, width)
grid = make_grid(frame, nrow=nrow, normalize=normalize, value_range=value_range)
frames.append(grid)
# Stack frames and convert to (frames, height, width, channels)
tensor = mx.stack(frames, axis=0) # (frames, channels, height, width)
tensor = mx.transpose(tensor, [0, 2, 3, 1]) # (frames, height, width, channels)
# Convert to uint8
tensor = (tensor * 255).astype(mx.uint8)
tensor_np = np.array(tensor)
# write video
writer = imageio.get_writer(
cache_file, fps=fps, codec='libx264', quality=8)
for frame in tensor_np:
writer.append_data(frame)
writer.close()
return cache_file
except Exception as e:
error = e
continue
else:
print(f'cache_video failed, error: {error}', flush=True)
return None
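# Usage sketch (illustrative, not part of the original file; the tensor name and fps
# value are assumptions): `videos` is expected in (batch, channels, frames, height,
# width) layout with values in value_range, e.g.
#
#     cache_video(videos, save_file="out.mp4", fps=16)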
def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
"""MLX equivalent of torchvision.utils.save_image"""
# Make grid
grid = make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range)
# Convert to (height, width, channels) and uint8
grid = mx.transpose(grid, [1, 2, 0]) # (height, width, channels)
grid = (grid * 255).astype(mx.uint8)
# Save using imageio
imageio.imwrite(save_file, np.array(grid))
def cache_image(tensor,
save_file,
nrow=8,
normalize=True,
value_range=(-1, 1),
retry=5):
# cache file
suffix = osp.splitext(save_file)[1]
if suffix.lower() not in [
'.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
]:
suffix = '.png'
# save to cache
error = None
for _ in range(retry):
try:
tensor = mx.clip(tensor, value_range[0], value_range[1])
save_image(
tensor,
save_file,
nrow=nrow,
normalize=normalize,
value_range=value_range)
return save_file
except Exception as e:
error = e
continue
else:
print(f'cache_image failed, error: {error}', flush=True)
return None
def str2bool(v):
"""
Convert a string to a boolean.
Supported true values: 'yes', 'true', 't', 'y', '1'
Supported false values: 'no', 'false', 'f', 'n', '0'
Args:
v (str): String to convert.
Returns:
bool: Converted boolean value.
Raises:
argparse.ArgumentTypeError: If the value cannot be converted to boolean.
"""
if isinstance(v, bool):
return v
v_lower = v.lower()
if v_lower in ('yes', 'true', 't', 'y', '1'):
return True
elif v_lower in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
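# Usage sketch (illustrative; the flag name is hypothetical): str2bool is intended as
# an argparse type so that "--offload true" and "--offload 0" both parse cleanly:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--offload", type=str2bool, default=False)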

View File

@ -0,0 +1,175 @@
import re
import torch
import mlx.core as mx
import numpy as np
from typing import Dict, Tuple
from safetensors import safe_open
def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = False):
"""
Convert PyTorch VAE weights to MLX format with correct mapping.
"""
print(f"Converting {pytorch_path} -> {output_path}")
dtype = mx.float16 if float16 else mx.float32
# Load PyTorch weights
if pytorch_path.endswith('.safetensors'):
weights = {}
with safe_open(pytorch_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensor = f.get_tensor(key)
if tensor.dtype == torch.bfloat16:
tensor = tensor.float()
weights[key] = tensor.numpy()
else:
checkpoint = torch.load(pytorch_path, map_location='cpu')
weights = {}
state_dict = checkpoint if isinstance(checkpoint, dict) and 'state_dict' not in checkpoint else checkpoint.get('state_dict', checkpoint)
for key, tensor in state_dict.items():
if tensor.dtype == torch.bfloat16:
tensor = tensor.float()
weights[key] = tensor.numpy()
# Convert weights
mlx_weights = {}
for key, value in weights.items():
# Skip these
if any(skip in key for skip in ["num_batches_tracked", "running_mean", "running_var"]):
continue
# Convert weight formats
if value.ndim == 5 and "weight" in key: # Conv3d weights
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
value = np.transpose(value, (0, 2, 3, 4, 1))
elif value.ndim == 4 and "weight" in key: # Conv2d weights
# PyTorch: (out_channels, in_channels, H, W)
# MLX Conv2d expects: (out_channels, H, W, in_channels)
value = np.transpose(value, (0, 2, 3, 1))
elif value.ndim == 1 and "bias" in key: # Conv biases
# Keep as is - MLX uses same format
pass
# Map the key
new_key = key
# Map residual block internals within Sequential
# PyTorch: encoder.downsamples.0.residual.0.gamma
# MLX: encoder.downsamples.layers.0.residual.layers.0.gamma
# Add .layers to Sequential modules
new_key = re.sub(r'\.downsamples\.(\d+)', r'.downsamples.layers.\1', new_key)
new_key = re.sub(r'\.upsamples\.(\d+)', r'.upsamples.layers.\1', new_key)
new_key = re.sub(r'\.middle\.(\d+)', r'.middle.layers.\1', new_key)
new_key = re.sub(r'\.head\.(\d+)', r'.head.layers.\1', new_key)
# Map residual Sequential internals
if ".residual." in new_key:
match = re.search(r'\.residual\.(\d+)\.', new_key)
if match:
idx = int(match.group(1))
if idx == 0: # First RMS_norm
new_key = re.sub(r'\.residual\.0\.', '.residual.layers.0.', new_key)
elif idx == 1: # SiLU - skip
continue
elif idx == 2: # First Conv3d
new_key = re.sub(r'\.residual\.2\.', '.residual.layers.2.', new_key)
elif idx == 3: # Second RMS_norm
new_key = re.sub(r'\.residual\.3\.', '.residual.layers.3.', new_key)
elif idx == 4: # Second SiLU - skip
continue
elif idx == 5: # Dropout - could be Identity in MLX
if "Dropout" in key:
continue
new_key = re.sub(r'\.residual\.5\.', '.residual.layers.5.', new_key)
elif idx == 6: # Second Conv3d
new_key = re.sub(r'\.residual\.6\.', '.residual.layers.6.', new_key)
# Map resample internals
if ".resample." in new_key:
# In both Encoder and Decoder Resample blocks, the Conv2d is at index 1
# in the nn.Sequential block, following either a padding or upsample layer.
# We just need to map PyTorch's .1 to MLX's .layers.1
if ".resample.1." in new_key:
new_key = new_key.replace(".resample.1.", ".resample.layers.1.")
# The layers at index 0 (ZeroPad2d, Upsample) have no weights, so we can
# safely skip any keys associated with them.
if ".resample.0." in key:
continue
# Map head internals (already using Sequential in MLX)
# Just need to handle the layers index
# Handle shortcut layers
if ".shortcut." in new_key and "Identity" not in key:
# Shortcut Conv3d layers - keep as is
pass
elif "Identity" in key:
# Skip Identity modules
continue
# Handle time_conv in Resample
if "time_conv" in new_key:
# Keep as is - already correctly named
pass
# Handle attention layers
if "to_qkv" in new_key or "proj" in new_key:
# Keep as is - already correctly named
pass
if "gamma" in new_key:
# Squeeze gamma from (C, 1, 1) or (C, 1, 1, 1) to just (C,)
value = np.squeeze(value) # This removes all dimensions of size 1
# Result will always be 1D array of shape (C,)
# Add to MLX weights
mlx_weights[new_key] = mx.array(value).astype(dtype)
# Verify critical layers are present
critical_prefixes = [
"encoder.conv1", "decoder.conv1", "conv1", "conv2",
"encoder.head.layers.2", "decoder.head.layers.2" # Updated for Sequential
]
for prefix in critical_prefixes:
found = any(k.startswith(prefix) for k in mlx_weights.keys())
if not found:
print(f"WARNING: No weights found for {prefix}")
print(f"Converted {len(mlx_weights)} parameters")
# Print a few example keys for verification
print("\nExample converted keys:")
for i, key in enumerate(sorted(mlx_weights.keys())[:10]):
print(f" {key}")
# Save
if output_path.endswith('.safetensors'):
mx.save_safetensors(output_path, mlx_weights)
else:
mx.savez(output_path, **mlx_weights)
print(f"\nSaved to {output_path}")
print("\nAll converted keys:")
for key in sorted(mlx_weights.keys()):
print(f" {key}: {mlx_weights[key].shape}")
return mlx_weights
if __name__ == "__main__":
import sys
if len(sys.argv) < 3:
print("Usage: python convert_vae_final.py <input.pth> <output.safetensors> [--fp16]")
else:
convert_pytorch_to_mlx(
sys.argv[1],
sys.argv[2],
"--fp16" in sys.argv
)

View File

@ -0,0 +1,228 @@
from typing import List, Tuple
import os
import mlx.core as mx
from mlx.utils import tree_unflatten
from safetensors import safe_open
import torch
import numpy as np
def map_wan_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
# Remove .layers. from PyTorch Sequential to match MLX Python lists
key = key.replace(".layers.", ".")
# Handle conv transpose if needed
if "patch_embedding.weight" in key:
value = mx.transpose(value, (0, 2, 3, 4, 1))
return [(key, value)]
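# Mapping example (illustrative; the key below is hypothetical): a PyTorch key such as
#     "blocks.0.ffn.layers.0.weight"
# becomes "blocks.0.ffn.0.weight", and "patch_embedding.weight" additionally has its
# Conv3d weight transposed from (O, I, D, H, W) to (O, D, H, W, I) for MLX.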
def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
"""Flatten nested list of parameter tuples"""
return [(k, v) for p in params for (k, v) in p]
def load_wan_from_safetensors(
safetensors_path: str,
model,
float16: bool = False
):
"""
Load WanModel weights from safetensors file(s) into MLX model.
"""
import glob
if os.path.isdir(safetensors_path):
# Multiple files (14B model) - only diffusion_mlx_model files
pattern = os.path.join(safetensors_path, "diffusion_mlx_model*.safetensors")
safetensor_files = sorted(glob.glob(pattern))
print(f"Found {len(safetensor_files)} diffusion_mlx_model safetensors files")
# Load all files and merge weights
all_weights = {}
for file_path in safetensor_files:
print(f"Loading: {file_path}")
weights = mx.load(file_path)
all_weights.update(weights)
model.update(tree_unflatten(list(all_weights.items())))
else:
# Single file (1.3B model)
print(f"Loading single file: {safetensors_path}")
weights = mx.load(safetensors_path)
model.update(tree_unflatten(list(weights.items())))
print("WanModel weights loaded successfully!")
return model
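# Usage sketch (illustrative; the model constructor and checkpoint path are
# assumptions, not defined in this file):
#
#     model = WanModel(**config)  # WanModel is defined elsewhere in the repo
#     load_wan_from_safetensors("path/to/diffusion_mlx_model.safetensors", model)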
def convert_safetensors_to_mlx_weights(
safetensors_path: str,
output_path: str,
float16: bool = False
):
"""
Convert safetensors file to MLX weights file.
Args:
safetensors_path: Input safetensors file
output_path: Output MLX weights file (.safetensors)
float16: Whether to use float16 precision
"""
dtype = mx.float16 if float16 else mx.float32
print(f"Converting safetensors to MLX format...")
print(f"Input: {safetensors_path}")
print(f"Output: {output_path}")
print(f"Target dtype: {dtype}")
# Load and convert weights
weights = {}
bfloat16_count = 0
original_keys = []
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
original_keys = list(f.keys()) # Store keys before closing
print(f"Processing {len(original_keys)} parameters...")
for key in original_keys:
tensor = f.get_tensor(key)
# Handle BFloat16
if tensor.dtype == torch.bfloat16:
bfloat16_count += 1
tensor = tensor.float() # Convert to float32 first
value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_wan_weights(key, value)
for new_key, new_value in mapped:
weights[new_key] = new_value
if bfloat16_count > 0:
print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to {dtype}")
# Print mapping summary
skipped = len(original_keys) - len(weights)
if skipped > 0:
print(f" Skipped {skipped} activation layer parameters")
# Save as MLX format
print(f"Saving {len(weights)} parameters to: {output_path}")
mx.save_safetensors(output_path, weights)
# Print a few example keys for verification
print("\nExample converted keys:")
for i, key in enumerate(sorted(weights.keys())[:10]):
print(f" {key}: {weights[key].shape}")
return weights
def convert_multiple_safetensors_to_mlx(
checkpoint_dir: str,
float16: bool = False
):
"""Convert multiple PyTorch safetensors files to MLX format."""
import glob
# Find all PyTorch model files
pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
pytorch_files = sorted(glob.glob(pytorch_pattern))
if not pytorch_files:
raise FileNotFoundError(f"No PyTorch model files found matching: {pytorch_pattern}")
print(f"Converting {len(pytorch_files)} PyTorch files to MLX format...")
for i, pytorch_file in enumerate(pytorch_files, 1):
# Extract the suffix (e.g., "00001-of-00006")
basename = os.path.basename(pytorch_file)
suffix = basename.replace("diffusion_pytorch_model-", "").replace(".safetensors", "")
# Create MLX filename
mlx_file = os.path.join(checkpoint_dir, f"diffusion_mlx_model-{suffix}.safetensors")
print(f"Converting {i}/{len(pytorch_files)}: {basename}")
convert_safetensors_to_mlx_weights(pytorch_file, mlx_file, float16)
print("All files converted successfully!")
def debug_weight_mapping(safetensors_path: str, float16: bool = False):
"""
Debug function to see how weights are being mapped.
"""
dtype = mx.float16 if float16 else mx.float32
print("=== WAN Weight Mapping Debug ===")
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
# Check first 20 keys to see the mapping
for i, key in enumerate(f.keys()):
if i >= 20:
break
tensor = f.get_tensor(key)
# Handle BFloat16
original_dtype = tensor.dtype
if tensor.dtype == torch.bfloat16:
tensor = tensor.float()
value = mx.array(tensor.numpy()).astype(dtype)
# Apply mapping
mapped = map_wan_weights(key, value)
if len(mapped) == 0:
print(f"SKIPPED: {key} ({original_dtype})")
elif len(mapped) == 1:
new_key, new_value = mapped[0]
if new_key == key:
print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]")
else:
print(f"MAPPED: {key} -> {new_key} [{tensor.shape}]")
def check_model_structure(model):
"""
Print the structure of an MLX model to debug loading issues.
"""
from mlx.utils import tree_flatten
print("=== Model Structure ===")
params = dict(tree_flatten(model))
print(f"Model has {len(params)} parameters")
print("\nFirst 20 parameter names:")
for i, (key, value) in enumerate(params.items()):
if i >= 20:
break
print(f" {key}: {value.shape}")
return params
# Example usage
if __name__ == "__main__":
import sys
if len(sys.argv) < 3:
print("Usage: python wan_model_io.py <input.safetensors> <output.safetensors> [--fp16]")
sys.exit(1)
input_file = sys.argv[1]
output_file = sys.argv[2]
use_fp16 = "--fp16" in sys.argv
# Debug the mapping first (optional)
debug_weight_mapping(input_file, use_fp16)
# Convert weights
convert_safetensors_to_mlx_weights(input_file, output_file, float16=use_fp16)
print("Conversion complete!")