Implement Wan 2.1
38
video/Wan2.1/.gitignore
vendored
Normal file
@@ -0,0 +1,38 @@
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
cache
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
venv_wan/
venv_wan_py310/
201
video/Wan2.1/LICENSE.txt
Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
83
video/Wan2.1/README.md
Normal file
@@ -0,0 +1,83 @@
# Wan2.1

## Quickstart

#### Installation
Install dependencies:
```
pip install -r requirements.txt
```

#### Model Download

| Models       | Download Link | Notes |
| ------------ | ------------- | ----- |
| T2V-14B      | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P |
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P |
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P |
| T2V-1.3B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P |

> 💡 Note: The 1.3B model can generate video at 720P, but because it received limited training at that resolution, the results are generally less stable than at 480P. For best results, we recommend 480P. Also note that the MLX port currently supports T2V only.

Download models using huggingface-cli:
```
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
```

Download models using modelscope-cli:
```
pip install modelscope
modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
```

#### Run Text-to-Video Generation

This repository currently supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:

<table>
    <thead>
        <tr>
            <th rowspan="2">Task</th>
            <th colspan="2">Resolution</th>
            <th rowspan="2">Model</th>
        </tr>
        <tr>
            <th>480P</th>
            <th>720P</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td>t2v-14B</td>
            <td style="color: green;">✔️</td>
            <td style="color: green;">✔️</td>
            <td>Wan2.1-T2V-14B</td>
        </tr>
        <tr>
            <td>t2v-1.3B</td>
            <td style="color: green;">✔️</td>
            <td style="color: red;">❌</td>
            <td>Wan2.1-T2V-1.3B</td>
        </tr>
    </tbody>
</table>

##### (1) Example:
```
python generate_mlx.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4
```
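
The CLI above wraps a small Python API (`wan.WanT2V`, added in this PR). For programmatic use, here is a minimal sketch that mirrors the calls `generate_mlx.py` makes internally; the checkpoint path, seed, and output filename are illustrative:

```python
# Minimal programmatic T2V sketch, following the calls made in generate_mlx.py.
# Assumes the Wan2.1-T2V-1.3B checkpoint has been downloaded as shown above.
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from wan.utils.utils_mlx_own import cache_video

cfg = WAN_CONFIGS["t2v-1.3B"]
pipeline = wan.WanT2V(config=cfg, checkpoint_dir="./Wan2.1-T2V-1.3B")

video = pipeline.generate(
    "Lion running under snow in Samarkand",
    size=SIZE_CONFIGS["480*832"],  # (480, 832)
    frame_num=16,
    shift=5.0,                     # default sample_shift for t2v tasks
    sample_solver="unipc",
    sampling_steps=25,
    guide_scale=5.0,
    seed=42,
    offload_model=True,
)

# cache_video expects a batched tensor normalized to [-1, 1]
cache_video(
    tensor=video[None],
    save_file="output_video_mlx.mp4",
    fps=cfg.sample_fps,
    nrow=1,
    normalize=True,
    value_range=(-1, 1),
)
```

As in the CLI path, `cache_video` comes from `wan.utils.utils_mlx_own` and writes the video at the config's `sample_fps`.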

## Citation
Credits to the Wan Team for the original PyTorch implementation.

```
@article{wan2.1,
    title   = {Wan: Open and Advanced Large-Scale Video Generative Models},
    author  = {Wan Team},
    journal = {},
    year    = {2025}
}
```
BIN  video/Wan2.1/assets/comp_effic.png  Normal file  (1.7 MiB)
BIN  video/Wan2.1/assets/data_for_diff_stage.jpg  Normal file  (516 KiB)
BIN  video/Wan2.1/assets/i2v_res.png  Normal file  (871 KiB)
BIN  video/Wan2.1/assets/logo.png  Normal file  (55 KiB)
BIN  video/Wan2.1/assets/t2v_res.jpg  Normal file  (294 KiB)
BIN  video/Wan2.1/assets/vben_vs_sota.png  Normal file  (1.5 MiB)
BIN  video/Wan2.1/assets/video_dit_arch.jpg  Normal file  (628 KiB)
BIN  video/Wan2.1/assets/video_vae_res.jpg  Normal file  (208 KiB)
BIN  video/Wan2.1/examples/i2v_input.JPG  Normal file  (245 KiB)
245
video/Wan2.1/generate_mlx.py
Normal file
@@ -0,0 +1,245 @@
import argparse
from datetime import datetime
import logging
import os
import sys
import random

import mlx.core as mx
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils_mlx_own import cache_video, cache_image, str2bool

EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2v-14B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2i-14B": {
        "prompt": "一个朴素端庄的美人",  # "a plain, dignified beauty"
    },
    "i2v-14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
}


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40 if "i2v" in args.task else 50

    if args.sample_shift is None:
        args.sample_shift = 5.0
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0

    # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[args.task], (
        f"Unsupported size {args.size} for task {args.task}, "
        f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate an image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from an image or video. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing memory usage."
    )
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")

    args = parser.parse_args()
    _validate_args(args)
    return args


def _init_logging():
    # logging
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[logging.StreamHandler(stream=sys.stdout)])


def generate(args):
    _init_logging()

    # MLX selects its default device automatically; no explicit device setup needed.

    if args.offload_model is None:
        args.offload_model = True  # Default to True to save memory
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if "t2v" in args.task or "t2i" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        logging.info(f"Input prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
        )

        logging.info(
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    else:
        # Note: the MLX port currently only exports WanT2V (see wan/__init__.py);
        # this i2v branch is retained from the PyTorch reference implementation.
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
        )

        logging.info("Generating video ...")
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    # Save output
    if args.save_file is None:
        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
        suffix = '.png' if "t2i" in args.task else '.mp4'
        args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix

    if "t2i" in args.task:
        logging.info(f"Saving generated image to {args.save_file}")
        # Note: cache_image might need to handle MLX arrays
        cache_image(
            tensor=video.squeeze(1)[None],
            save_file=args.save_file,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    else:
        logging.info(f"Saving generated video to {args.save_file}")
        cache_video(
            tensor=video[None],
            save_file=args.save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)
18
video/Wan2.1/requirements.txt
Normal file
@@ -0,0 +1,18 @@
torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
# flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2
mlx
scikit-image
6
video/Wan2.1/tests/README.md
Normal file
@@ -0,0 +1,6 @@

Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in one folder and specify the maximum number of GPUs you want to use.

```bash
bash ./test.sh <local model dir> <gpu number>
```
113
video/Wan2.1/tests/test.sh
Normal file
@@ -0,0 +1,113 @@
#!/bin/bash


if [ "$#" -eq 2 ]; then
    MODEL_DIR=$(realpath "$1")
    GPUS=$2
else
    echo "Usage: $0 <local model dir> <gpu number>"
    exit 1
fi

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1

# Note: this script drives the upstream PyTorch entry point (generate.py) with
# torchrun/FSDP flags; it does not exercise the MLX port (generate_mlx.py).
PY_FILE=./generate.py


function t2v_1_3B() {
    T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
    python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}

function t2v_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
    python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}


function t2i_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
    python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}


function i2v_14B_480p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}


function i2v_14B_720p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}


t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
2
video/Wan2.1/wan/__init__.py
Normal file
@@ -0,0 +1,2 @@
from . import configs, modules
from .text2video_mlx import WanT2V
42
video/Wan2.1/wan/configs/__init__.py
Normal file
@@ -0,0 +1,42 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B

# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'

WAN_CONFIGS = {
    't2v-14B': t2v_14B,
    't2v-1.3B': t2v_1_3B,
    'i2v-14B': i2v_14B,
    't2i-14B': t2i_14B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '1024*1024': (1024, 1024),
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
}

SUPPORTED_SIZES = {
    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2v-1.3B': ('480*832', '832*480'),
    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
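
A quick illustration of how these tables are consumed; this mirrors the checks in `_validate_args` in `generate_mlx.py` above, and the task/size values are arbitrary examples:

```python
# Hypothetical lookup sketch mirroring _validate_args in generate_mlx.py.
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

task, size = "t2v-1.3B", "480*832"
assert size in SUPPORTED_SIZES[task]     # e.g. 720P sizes are rejected for the 1.3B model
cfg = WAN_CONFIGS[task]                  # EasyDict of model hyperparameters
print(cfg.__name__, SIZE_CONFIGS[size])  # Config: Wan T2V 1.3B (480, 832)
```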
19
video/Wan2.1/wan/configs/shared_config.py
Normal file
@@ -0,0 +1,19 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.float32
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.float32

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt (Chinese); roughly: "garish colors, overexposed, static,
# blurry details, subtitles, style, artwork, painting, frame, still, overall gray,
# worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed
# limbs, fused fingers, motionless frame, cluttered background, three legs, many
# people in the background, walking backwards"
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
35
video/Wan2.1/wan/configs/wan_i2v_14B.py
Normal file
@@ -0,0 +1,35 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V 14B ------------------------#

i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)

i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float32
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)

# transformer
i2v_14B.patch_size = (1, 2, 2)
i2v_14B.dim = 5120
i2v_14B.ffn_dim = 13824
i2v_14B.freq_dim = 256
i2v_14B.num_heads = 40
i2v_14B.num_layers = 40
i2v_14B.window_size = (-1, -1)
i2v_14B.qk_norm = True
i2v_14B.cross_attn_norm = True
i2v_14B.eps = 1e-6
29
video/Wan2.1/wan/configs/wan_t2v_14B.py
Normal file
@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 14B ------------------------#

t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_14B.vae_checkpoint = 'vae_mlx.safetensors'
t2v_14B.vae_stride = (4, 8, 8)

# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6
29
video/Wan2.1/wan/configs/wan_t2v_1_3B.py
Normal file
@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 1.3B ------------------------#

t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_1_3B.vae_checkpoint = 'vae_mlx.safetensors'
t2v_1_3B.vae_stride = (4, 8, 8)

# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6
14
video/Wan2.1/wan/modules/__init__.py
Normal file
@@ -0,0 +1,14 @@
from .model_mlx import WanModel
from .t5_mlx import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae_mlx import WanVAE

__all__ = [
    'WanVAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
]
787
video/Wan2.1/wan/modules/model_mlx.py
Normal file
@ -0,0 +1,787 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
# MLX Implementation of WAN Model - True 1:1 Port from PyTorch
|
||||
|
||||
import math
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['WanModel']
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
||||
"""Generate sinusoidal position embeddings."""
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
position = position.astype(mx.float32)
|
||||
|
||||
# Calculate sinusoidal embeddings
|
||||
div_term = mx.power(10000, mx.arange(half).astype(mx.float32) / half)
|
||||
sinusoid = mx.expand_dims(position, 1) / mx.expand_dims(div_term, 0)
|
||||
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
||||
|
||||
|
||||
def rope_params(max_seq_len: int, dim: int, theta: float = 10000) -> mx.array:
|
||||
"""Generate RoPE (Rotary Position Embedding) parameters."""
|
||||
assert dim % 2 == 0
|
||||
positions = mx.arange(max_seq_len)
|
||||
freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
|
||||
freqs = 1.0 / mx.power(theta, freqs)
|
||||
|
||||
# Outer product
|
||||
freqs = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)
|
||||
|
||||
# Convert to complex representation
|
||||
return mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
|
||||
|
||||
|
||||
def rope_apply(x: mx.array, grid_sizes: mx.array, freqs: mx.array) -> mx.array:
|
||||
"""Apply rotary position embeddings to input tensor."""
|
||||
n, c_half = x.shape[2], x.shape[3] // 2
|
||||
|
||||
# Split frequencies for different dimensions
|
||||
c_split = c_half - 2 * (c_half // 3)
|
||||
freqs_splits = [
|
||||
freqs[:, :c_split],
|
||||
freqs[:, c_split:c_split + c_half // 3],
|
||||
freqs[:, c_split + c_half // 3:]
|
||||
]
|
||||
|
||||
output = []
|
||||
for i in range(grid_sizes.shape[0]):
|
||||
f, h, w = int(grid_sizes[i, 0]), int(grid_sizes[i, 1]), int(grid_sizes[i, 2])
|
||||
seq_len = f * h * w
|
||||
|
||||
# Extract sequence for current sample
|
||||
x_i = x[i, :seq_len].astype(mx.float32)
|
||||
x_i = x_i.reshape(seq_len, n, -1, 2)
|
||||
|
||||
# Prepare frequency tensors
|
||||
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
|
||||
freqs_f = mx.broadcast_to(freqs_f, (f, h, w, freqs_f.shape[-2], 2))
|
||||
|
||||
freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
|
||||
freqs_h = mx.broadcast_to(freqs_h, (f, h, w, freqs_h.shape[-2], 2))
|
||||
|
||||
freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
|
||||
freqs_w = mx.broadcast_to(freqs_w, (f, h, w, freqs_w.shape[-2], 2))
|
||||
|
||||
# Concatenate and reshape frequencies
|
||||
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=-2)
|
||||
freqs_i = freqs_i.reshape(seq_len, 1, -1, 2)
|
||||
|
||||
# Apply rotary embedding
|
||||
x_real = x_i[..., 0]
|
||||
x_imag = x_i[..., 1]
|
||||
freqs_cos = freqs_i[..., 0]
|
||||
freqs_sin = freqs_i[..., 1]
|
||||
|
||||
x_rotated_real = x_real * freqs_cos - x_imag * freqs_sin
|
||||
x_rotated_imag = x_real * freqs_sin + x_imag * freqs_cos
|
||||
|
||||
x_i = mx.stack([x_rotated_real, x_rotated_imag], axis=-1).reshape(seq_len, n, -1)
|
||||
|
||||
# Concatenate with remaining sequence if any
|
||||
if seq_len < x.shape[1]:
|
||||
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
|
||||
|
||||
output.append(x_i)
|
||||
|
||||
return mx.stack(output).astype(x.dtype)
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""Root Mean Square Layer Normalization."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# RMS normalization
|
||||
variance = mx.mean(mx.square(x.astype(mx.float32)), axis=-1, keepdims=True)
|
||||
x_normed = x * mx.rsqrt(variance + self.eps)
|
||||
return (x_normed * self.weight).astype(x.dtype)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.Module):
|
||||
"""Layer normalization."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Standard layer normalization
|
||||
x_float = x.astype(mx.float32)
|
||||
mean = mx.mean(x_float, axis=-1, keepdims=True)
|
||||
variance = mx.var(x_float, axis=-1, keepdims=True)
|
||||
x_normed = (x_float - mean) * mx.rsqrt(variance + self.eps)
|
||||
|
||||
if self.elementwise_affine:
|
||||
x_normed = x_normed * self.weight + self.bias
|
||||
|
||||
return x_normed.astype(x.dtype)
|
||||
|
||||
|
||||
def mlx_attention(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
v: mx.array,
|
||||
q_lens: Optional[mx.array] = None,
|
||||
k_lens: Optional[mx.array] = None,
|
||||
dropout_p: float = 0.,
|
||||
softmax_scale: Optional[float] = None,
|
||||
q_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
deterministic: bool = False,
|
||||
dtype: Optional[type] = None,
|
||||
) -> mx.array:
|
||||
"""
|
||||
MLX implementation of scaled dot-product attention.
|
||||
"""
|
||||
# Get shapes
|
||||
b, lq, n, d = q.shape
|
||||
_, lk, _, _ = k.shape
|
||||
|
||||
# Scale queries if needed
|
||||
if q_scale is not None:
|
||||
q = q * q_scale
|
||||
|
||||
# Compute attention scores
|
||||
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
|
||||
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
|
||||
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]
|
||||
|
||||
# Compute attention scores
|
||||
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]
|
||||
|
||||
# Apply softmax scale if provided
|
||||
if softmax_scale is not None:
|
||||
scores = scores * softmax_scale
|
||||
else:
|
||||
# Default scaling by sqrt(d)
|
||||
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = None
|
||||
|
||||
# Apply window size masking if specified
|
||||
if window_size != (-1, -1):
|
||||
left_window, right_window = window_size
|
||||
window_mask = mx.zeros((lq, lk))
|
||||
for i in range(lq):
|
||||
start = max(0, i - left_window)
|
||||
end = min(lk, i + right_window + 1)
|
||||
window_mask[i, start:end] = 1
|
||||
attn_mask = window_mask
|
||||
|
||||
# Apply causal masking if needed
|
||||
if causal:
|
||||
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
|
||||
if attn_mask is None:
|
||||
attn_mask = causal_mask
|
||||
else:
|
||||
attn_mask = mx.logical_and(attn_mask, causal_mask)
|
||||
|
||||
# Apply attention mask if present
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.astype(scores.dtype)
|
||||
scores = scores * attn_mask + (1 - attn_mask) * -1e4
|
||||
|
||||
# Apply attention mask if lengths are provided
|
||||
if q_lens is not None or k_lens is not None:
|
||||
if q_lens is not None:
|
||||
mask = mx.arange(lq)[None, :] < q_lens[:, None]
|
||||
mask = mask.astype(scores.dtype)
|
||||
scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
|
||||
if k_lens is not None:
|
||||
mask = mx.arange(lk)[None, :] < k_lens[:, None]
|
||||
mask = mask.astype(scores.dtype)
|
||||
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
|
||||
|
||||
# Apply softmax
|
||||
max_scores = mx.max(scores, axis=-1, keepdims=True)
|
||||
scores = scores - max_scores
|
||||
exp_scores = mx.exp(scores)
|
||||
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
|
||||
attn = exp_scores / (sum_exp + 1e-6)
|
||||
|
||||
# Apply dropout if needed
|
||||
if dropout_p > 0 and not deterministic:
|
||||
raise NotImplementedError("Dropout not implemented in MLX version")
|
||||
|
||||
# Compute output
|
||||
out = mx.matmul(attn, v) # [b, n, lq, d]
|
||||
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
"""Self-attention module with RoPE and optional QK normalization."""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, window_size: Tuple[int, int] = (-1, -1),
|
||||
qk_norm: bool = True, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
# Linear projections
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
# Normalization layers
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def __call__(self, x: mx.array, seq_lens: mx.array, grid_sizes: mx.array,
|
||||
freqs: mx.array) -> mx.array:
|
||||
b, s = x.shape[0], x.shape[1]
|
||||
|
||||
# Compute Q, K, V
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
|
||||
if self.qk_norm:
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
q = q.reshape(b, s, self.num_heads, self.head_dim)
|
||||
k = k.reshape(b, s, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, s, self.num_heads, self.head_dim)
|
||||
|
||||
# Apply RoPE
|
||||
q = rope_apply(q, grid_sizes, freqs)
|
||||
k = rope_apply(k, grid_sizes, freqs)
|
||||
|
||||
# Apply attention
|
||||
x = mlx_attention(q, k, v, k_lens=seq_lens, window_size=self.window_size)
|
||||
|
||||
# Reshape and project output
|
||||
x = x.reshape(b, s, self.dim)
|
||||
x = self.o(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
|
||||
"""Text-to-video cross attention."""
|
||||
|
||||
def __call__(self, x: mx.array, context: mx.array, context_lens: mx.array) -> mx.array:
|
||||
b = x.shape[0]
|
||||
|
||||
# Compute queries from x
|
||||
q = self.q(x)
|
||||
if self.qk_norm:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
|
||||
# Compute keys and values from context
|
||||
k = self.k(context)
|
||||
v = self.v(context)
|
||||
if self.qk_norm:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
|
||||
# Apply attention
|
||||
x = mlx_attention(q, k, v, k_lens=context_lens)
|
||||
|
||||
# Reshape and project output
|
||||
x = x.reshape(b, -1, self.dim)
|
||||
x = self.o(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanI2VCrossAttention(WanSelfAttention):
|
||||
"""Image-to-video cross attention."""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, window_size: Tuple[int, int] = (-1, -1),
|
||||
qk_norm: bool = True, eps: float = 1e-6):
|
||||
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
self.k_img = nn.Linear(dim, dim)
|
||||
self.v_img = nn.Linear(dim, dim)
|
||||
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def __call__(self, x: mx.array, context: mx.array, context_lens: mx.array) -> mx.array:
|
||||
# Split context into image and text parts
|
||||
context_img = context[:, :257]
|
||||
context = context[:, 257:]
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
# Compute queries
|
||||
q = self.q(x)
|
||||
if self.qk_norm:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
|
||||
# Compute keys and values for text
|
||||
k = self.k(context)
|
||||
v = self.v(context)
|
||||
if self.qk_norm:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
|
||||
# Compute keys and values for image
|
||||
k_img = self.k_img(context_img)
|
||||
v_img = self.v_img(context_img)
|
||||
if self.qk_norm:
|
||||
k_img = self.norm_k_img(k_img)
|
||||
k_img = k_img.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
v_img = v_img.reshape(b, -1, self.num_heads, self.head_dim)
|
||||
|
||||
# Apply attention
|
||||
img_x = mlx_attention(q, k_img, v_img, k_lens=None)
|
||||
x = mlx_attention(q, k, v, k_lens=context_lens)
|
||||
|
||||
# Combine and project
|
||||
img_x = img_x.reshape(b, -1, self.dim)
|
||||
x = x.reshape(b, -1, self.dim)
|
||||
x = x + img_x
|
||||
x = self.o(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
WAN_CROSSATTENTION_CLASSES = {
|
||||
't2v_cross_attn': WanT2VCrossAttention,
|
||||
'i2v_cross_attn': WanI2VCrossAttention,
|
||||
}
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
"""Transformer block with self-attention, cross-attention, and FFN."""
|
||||
|
||||
def __init__(self, cross_attn_type: str, dim: int, ffn_dim: int, num_heads: int,
|
||||
window_size: Tuple[int, int] = (-1, -1), qk_norm: bool = True,
|
||||
cross_attn_norm: bool = False, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# Layers
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
|
||||
dim, num_heads, (-1, -1), qk_norm, eps)
|
||||
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
|
||||
# FFN - use a list instead of Sequential to match PyTorch exactly!
|
||||
self.ffn = [
|
||||
nn.Linear(dim, ffn_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(ffn_dim, dim)
|
||||
]
|
||||
|
||||
# Modulation parameters
|
||||
self.modulation = mx.random.normal((1, 6, dim)) / math.sqrt(dim)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array, seq_lens: mx.array,
|
||||
grid_sizes: mx.array, freqs: mx.array, context: mx.array,
|
||||
context_lens: Optional[mx.array]) -> mx.array:
|
||||
# Apply modulation
|
||||
e = (self.modulation + e).astype(mx.float32)
|
||||
e_chunks = [mx.squeeze(chunk, axis=1) for chunk in mx.split(e, 6, axis=1)]
|
||||
|
||||
# Self-attention with modulation
|
||||
y = self.norm1(x).astype(mx.float32)
|
||||
y = y * (1 + e_chunks[1]) + e_chunks[0]
|
||||
y = self.self_attn(y, seq_lens, grid_sizes, freqs)
|
||||
x = x + y * e_chunks[2]
|
||||
|
||||
# Cross-attention
|
||||
if self.cross_attn_norm and isinstance(self.norm3, WanLayerNorm):
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
||||
else:
|
||||
x = x + self.cross_attn(x, context, context_lens)
|
||||
|
||||
# FFN with modulation
|
||||
y = self.norm2(x).astype(mx.float32)
|
||||
y = y * (1 + e_chunks[4]) + e_chunks[3]
|
||||
|
||||
# Apply FFN layers manually
|
||||
y = self.ffn[0](y) # Linear
|
||||
y = self.ffn[1](y) # GELU
|
||||
y = self.ffn[2](y) # Linear
|
||||
|
||||
x = x + y * e_chunks[5]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
"""Output head for final projection."""
|
||||
|
||||
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int],
|
||||
eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
# Output projection
|
||||
out_features = int(np.prod(patch_size)) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, out_features)
|
||||
|
||||
# Modulation
|
||||
self.modulation = mx.random.normal((1, 2, dim)) / math.sqrt(dim)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
# Apply modulation
|
||||
e = (self.modulation + mx.expand_dims(e, 1)).astype(mx.float32)
|
||||
e_chunks = mx.split(e, 2, axis=1)
|
||||
|
||||
# Apply normalization and projection with modulation
|
||||
x = self.norm(x) * (1 + e_chunks[1]) + e_chunks[0]
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MLPProj(nn.Module):
|
||||
"""MLP projection for image embeddings."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
|
||||
# Use a list to match PyTorch Sequential indexing
|
||||
self.proj = [
|
||||
nn.LayerNorm(in_dim),
|
||||
nn.Linear(in_dim, in_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(in_dim, out_dim),
|
||||
nn.LayerNorm(out_dim)
|
||||
]
|
||||
|
||||
def __call__(self, image_embeds: mx.array) -> mx.array:
|
||||
x = image_embeds
|
||||
for layer in self.proj:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||

class WanModel(nn.Module):
    """
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    MLX implementation - true 1:1 port from PyTorch.
    """

    def __init__(
        self,
        model_type: str = 't2v',
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        text_len: int = 512,
        in_dim: int = 16,
        dim: int = 2048,
        ffn_dim: int = 8192,
        freq_dim: int = 256,
        text_dim: int = 4096,
        out_dim: int = 16,
        num_heads: int = 16,
        num_layers: int = 32,
        window_size: Tuple[int, int] = (-1, -1),
        qk_norm: bool = True,
        cross_attn_norm: bool = True,
        eps: float = 1e-6
    ):
        super().__init__()

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        # Store configuration
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # Embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Use lists instead of Sequential to match PyTorch indexing
        self.text_embedding = [
            nn.Linear(text_dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        ]

        self.time_embedding = [
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        ]

        self.time_projection = [
            nn.SiLU(),
            nn.Linear(dim, dim * 6)
        ]

        # Transformer blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = [
            WanAttentionBlock(
                cross_attn_type, dim, ffn_dim, num_heads,
                window_size, qk_norm, cross_attn_norm, eps
            )
            for _ in range(num_layers)
        ]

        # Output head
        self.head = Head(dim, out_dim, patch_size, eps)

        # Precompute RoPE frequencies
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], axis=1)

        # Image embedding for i2v
        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim)

        # Initialize weights
        self.init_weights()
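    # Worked example of the RoPE split above (illustrative, not upstream
    # code): with dim=2048 and num_heads=16, d = 128 and d // 6 = 21, so the
    # three rope_params calls cover 128 - 4*21 = 44 temporal dims plus
    # 2*21 = 42 dims for each spatial axis; 44 + 42 + 42 = 128 = d, the full
    # per-head dimension.
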
    def __call__(
        self,
        x: List[mx.array],
        t: mx.array,
        context: List[mx.array],
        seq_len: int,
        clip_fea: Optional[mx.array] = None,
        y: Optional[List[mx.array]] = None
    ) -> List[mx.array]:
        """
        Forward pass through the diffusion model.

        Args:
            x: List of input video tensors [C_in, F, H, W]
            t: Diffusion timesteps [B]
            context: List of text embeddings [L, C]
            seq_len: Maximum sequence length
            clip_fea: CLIP image features for i2v mode
            y: Conditional video inputs for i2v mode

        Returns:
            List of denoised video tensors [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v':
            assert clip_fea is not None and y is not None

        # Concatenate conditional inputs if provided
        if y is not None:
            x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]

        # Patch embedding (MLX Conv3d expects channels-last input)
        x = [mx.transpose(mx.expand_dims(u, 0), (0, 2, 3, 4, 1)) for u in x]
        x = [self.patch_embedding(u) for u in x]
        # Transpose back from MLX format (N, D, H, W, C) to (N, C, D, H, W) for the rest of the model
        x = [mx.transpose(u, (0, 4, 1, 2, 3)) for u in x]
        grid_sizes = mx.array([[u.shape[2], u.shape[3], u.shape[4]] for u in x])

        # Flatten spatial dimensions
        x = [mx.transpose(u.reshape(u.shape[0], u.shape[1], -1), (0, 2, 1)) for u in x]
        seq_lens = mx.array([u.shape[1] for u in x])

        # Pad sequences to max length
        x_padded = []
        for u in x:
            if u.shape[1] < seq_len:
                padding = mx.zeros((1, seq_len - u.shape[1], u.shape[2]))
                u = mx.concatenate([u, padding], axis=1)
            x_padded.append(u)
        x = mx.concatenate(x_padded, axis=0)

        # Time embeddings - apply layers manually
        e = sinusoidal_embedding_1d(self.freq_dim, t).astype(mx.float32)
        e = self.time_embedding[0](e)  # Linear
        e = self.time_embedding[1](e)  # SiLU
        e = self.time_embedding[2](e)  # Linear

        # Time projection; keep `e` intact, since the output head expects the
        # unprojected time embedding (the original code overwrote `e` here)
        e_proj = self.time_projection[0](e)  # SiLU
        e0 = self.time_projection[1](e_proj).reshape(-1, 6, self.dim)  # Linear

        # Process context
        context_lens = None
        context_padded = []
        for u in context:
            if u.shape[0] < self.text_len:
                padding = mx.zeros((self.text_len - u.shape[0], u.shape[1]))
                u = mx.concatenate([u, padding], axis=0)
            context_padded.append(u)
        context = mx.stack(context_padded)

        # Apply text embedding layers manually
        context = self.text_embedding[0](context)  # Linear
        context = self.text_embedding[1](context)  # GELU
        context = self.text_embedding[2](context)  # Linear

        # Add image embeddings for i2v
        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)
            context = mx.concatenate([context_clip, context], axis=1)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(
                x, e0, seq_lens, grid_sizes, self.freqs,
                context, context_lens
            )

        # Apply output head
        x = self.head(x, e)

        # Unpatchify
        x = self.unpatchify(x, grid_sizes)

        return [u.astype(mx.float32) for u in x]

    def unpatchify(self, x: mx.array, grid_sizes: mx.array) -> List[mx.array]:
        """Reconstruct video tensors from patch embeddings."""
        c = self.out_dim
        out = []

        for i in range(grid_sizes.shape[0]):
            f, h, w = int(grid_sizes[i, 0]), int(grid_sizes[i, 1]), int(grid_sizes[i, 2])
            seq_len = f * h * w

            # Extract relevant sequence
            u = x[i, :seq_len]

            # Reshape to grid with patches
            pf, ph, pw = self.patch_size
            u = u.reshape(f, h, w, pf, ph, pw, c)

            # Rearrange dimensions
            u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))

            # Combine patches
            u = u.reshape(c, f * pf, h * ph, w * pw)

            out.append(u)

        return out
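    # Illustrative round trip (hypothetical numbers): a latent with grid size
    # (f, h, w) = (21, 30, 52) gives seq_len = 21*30*52 = 32760 tokens of
    # 1*2*2*16 = 64 features each; unpatchify folds them back into a
    # (16, 21, 60, 104) video latent.
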
    def init_weights(self):
        """Initialize model parameters using Xavier/He initialization."""
        # Note: MLX doesn't have nn.init like PyTorch, so we manually initialize

        # Helper function for Xavier uniform initialization
        def xavier_uniform(shape):
            bound = math.sqrt(6.0 / (shape[0] + shape[1]))
            return mx.random.uniform(low=-bound, high=bound, shape=shape)

        # Initialize linear layers in blocks
        for block in self.blocks:
            # Self attention
            block.self_attn.q.weight = xavier_uniform(block.self_attn.q.weight.shape)
            block.self_attn.k.weight = xavier_uniform(block.self_attn.k.weight.shape)
            block.self_attn.v.weight = xavier_uniform(block.self_attn.v.weight.shape)
            block.self_attn.o.weight = xavier_uniform(block.self_attn.o.weight.shape)

            # Cross attention
            block.cross_attn.q.weight = xavier_uniform(block.cross_attn.q.weight.shape)
            block.cross_attn.k.weight = xavier_uniform(block.cross_attn.k.weight.shape)
            block.cross_attn.v.weight = xavier_uniform(block.cross_attn.v.weight.shape)
            block.cross_attn.o.weight = xavier_uniform(block.cross_attn.o.weight.shape)

            # FFN layers - a list, so index the Linear entries directly
            block.ffn[0].weight = xavier_uniform(block.ffn[0].weight.shape)
            block.ffn[2].weight = xavier_uniform(block.ffn[2].weight.shape)

            # Modulation
            block.modulation = mx.random.normal(
                shape=(1, 6, self.dim),
                scale=1.0 / math.sqrt(self.dim)
            )

        # Special initialization for embeddings
        # Patch embedding - Xavier uniform. MLX Conv3d weights are
        # channels-last, (out_channels, kD, kH, kW, in_channels), so fan-in is
        # in_channels times the kernel volume.
        weight_shape = self.patch_embedding.weight.shape
        fan_in = weight_shape[-1] * int(np.prod(self.patch_size))
        fan_out = weight_shape[0]
        bound = math.sqrt(6.0 / (fan_in + fan_out))
        self.patch_embedding.weight = mx.random.uniform(
            low=-bound,
            high=bound,
            shape=weight_shape
        )

        # Text embedding - normal distribution with std=0.02
        self.text_embedding[0].weight = mx.random.normal(shape=self.text_embedding[0].weight.shape, scale=0.02)
        self.text_embedding[2].weight = mx.random.normal(shape=self.text_embedding[2].weight.shape, scale=0.02)

        # Time embedding - normal distribution with std=0.02
        self.time_embedding[0].weight = mx.random.normal(shape=self.time_embedding[0].weight.shape, scale=0.02)
        self.time_embedding[2].weight = mx.random.normal(shape=self.time_embedding[2].weight.shape, scale=0.02)

        # Output head - initialize to zeros
        self.head.head.weight = mx.zeros(self.head.head.weight.shape)

        # Head modulation
        self.head.modulation = mx.random.normal(
            shape=(1, 2, self.dim),
            scale=1.0 / math.sqrt(self.dim)
        )

        # Initialize i2v specific layers if present
        if self.model_type == 'i2v':
            for i in [1, 3]:  # Linear layers in the proj list
                if isinstance(self.img_emb.proj[i], nn.Linear):
                    self.img_emb.proj[i].weight = xavier_uniform(self.img_emb.proj[i].weight.shape)
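
# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the upstream file). The tiny config
# below is hypothetical and only checks that the model builds; it assumes the
# WanAttentionBlock and rope_params definitions earlier in this file.
if __name__ == '__main__':
    _m = WanModel(
        model_type='t2v',
        dim=192,        # must be divisible by num_heads; head dim must be even
        ffn_dim=384,
        freq_dim=32,
        text_dim=64,
        num_heads=6,    # head dim = 192 // 6 = 32
        num_layers=2,
    )
    print('blocks:', len(_m.blocks), '| rope table shape:', _m.freqs.shape)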
513
video/Wan2.1/wan/modules/t5.py
Normal file
@ -0,0 +1,513 @@
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
]


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.float32]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def forward(self,
                x,
                mask=None,
                encoder_states=None,
                encoder_mask=None,
                pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets

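
# Worked example of the bucketing above (illustrative, not upstream code):
# with num_buckets=32 and bidirectional=True, each direction gets 16 buckets
# and max_exact = 8, so distances 0-7 map to their own bucket while larger
# distances fall into logarithmically spaced buckets 8-15. For instance, a
# key 50 positions to the left of the query lands in bucket
# 8 + floor(log(50/8) / log(128/8) * 8) = 8 + 5 = 13.
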

class T5Encoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, encoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        dtype=torch.float32,
        device='cpu',
        **kwargs):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)

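
# Illustrative: cfg.update(**kwargs) lets callers override any field, so a
# hypothetical reduced setup for unit tests could look like
#
#   enc = umt5_xxl(encoder_only=True, dim=64, dim_attn=64, dim_ffn=128,
#                  num_heads=4, encoder_layers=2, decoder_layers=2,
#                  vocab_size=1000)
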

class T5EncoderModel:

    def __init__(
        self,
        text_len,
        dtype=torch.float32,
        device='mps' if torch.backends.mps.is_available() else 'cpu',
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
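
# Illustrative usage sketch (not part of the upstream file; the checkpoint
# and tokenizer paths are hypothetical placeholders):
#
#   encoder = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='path/to/umt5-xxl-enc.pth',
#       tokenizer_path='google/umt5-xxl')
#   context = encoder(['a cat surfing a wave'], device='cpu')
#   # context is a list of [seq_len_i, 4096] tensors, one per prompt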
617
video/Wan2.1/wan/modules/t5_mlx.py
Normal file
@ -0,0 +1,617 @@
# Modified from transformers.models.t5.modeling_t5 for MLX
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Optional, Tuple, List

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
]


def fp16_clamp(x):
    if x.dtype == mx.float16:
        # Clamp to the float16 range. The PyTorch version clamps only when
        # infs are present; clamping unconditionally is a no-op otherwise.
        clamp = 65504.0  # max finite value for float16
        return mx.clip(x, -clamp, clamp)
    return x


class GELU(nn.Module):
    def __call__(self, x):
        return 0.5 * x * (1.0 + mx.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))


class T5LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x):
        # Match PyTorch's approach: compute in float32 for stability
        x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
        variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
        x_norm = x_float * mx.rsqrt(variance + self.eps)
        # Convert back to the original dtype
        if x.dtype == mx.float16:
            x_norm = x_norm.astype(mx.float16)
        return self.weight * x_norm
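
# Note: T5LayerNorm above is RMSNorm (no mean subtraction, no bias):
#
#     y = w * x / sqrt(mean(x_i^2) + eps)
#
# unlike standard LayerNorm, which first centers x by its mean.
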

class T5Attention(nn.Module):
    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, l1, _ = x.shape
        _, l2, _ = context.shape
        n, c = self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, l1, n, c)
        k = self.k(context).reshape(b, l2, n, c)
        v = self.v(context).reshape(b, l2, n, c)

        # transpose for attention: [B, N, L, C]
        q = mx.transpose(q, (0, 2, 1, 3))
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

        # compute attention (T5 does not use scaling)
        attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))  # [B, N, L1, L2]

        # add position bias if provided
        if pos_bias is not None:
            attn = attn + pos_bias

        # apply mask
        if mask is not None:
            if mask.ndim == 2:
                # [B, L2] -> [B, 1, 1, L2]
                mask = mask[:, None, None, :]
            elif mask.ndim == 3:
                # [B, L1, L2] -> [B, 1, L1, L2]
                mask = mask[:, None, :, :]
            # Use a very negative value that works well with float16
            min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
            attn = mx.where(mask == 0, min_value, attn)

        # softmax and apply attention
        attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
        attn = self.dropout(attn)

        # apply attention to values
        x = mx.matmul(attn, v)  # [B, N, L1, C]

        # transpose back and reshape
        x = mx.transpose(x, (0, 2, 1, 3))  # [B, L1, N, C]
        x = x.reshape(b, l1, -1)

        # output projection
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):
    def __init__(self, dim, dim_ffn, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
        self.gate_act = GELU()
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x):
        gate = self.gate_act(self.gate_proj(x))
        x = self.fc1(x) * gate
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def __call__(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.shape[1], x.shape[1])
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):
    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def __call__(self,
                 x,
                 mask=None,
                 encoder_states=None,
                 encoder_mask=None,
                 pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.shape[1], x.shape[1])
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def __call__(self, lq, lk):
        # Create relative position matrix
        positions_q = mx.arange(lq)[:, None]
        positions_k = mx.arange(lk)[None, :]
        rel_pos = positions_k - positions_q

        # Apply bucketing
        rel_pos = self._relative_position_bucket(rel_pos)

        # Get embeddings
        rel_pos_embeds = self.embedding(rel_pos)

        # Reshape to [1, N, Lq, Lk]
        rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
        rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)

        return rel_pos_embeds
    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
            rel_pos = mx.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            # mx.zeros_like takes no dtype argument, so build the zeros directly
            rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
            rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        is_small = rel_pos < max_exact

        # For large positions, use log scale
        rel_pos_large = max_exact + (
            mx.log(rel_pos.astype(mx.float32) / max_exact) /
            math.log(self.max_dist / max_exact) *
            (num_buckets - max_exact)
        ).astype(mx.int32)

        rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)

        # Combine small and large position buckets
        rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)

        return rel_buckets


class T5Encoder(nn.Module):
    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)

        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = [
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.shape[1],
                               x.shape[1]) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):
    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)

        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = [
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.shape

        # causal mask
        if mask is None:
            mask = mx.tril(mx.ones((1, s, s)))
        elif mask.ndim == 2:
            # Expand the padding mask to [B, s, s] before applying causality
            # (mx.broadcast_to is a module function, not an array method)
            mask = mx.tril(mx.broadcast_to(mx.expand_dims(mask, 1), (b, s, s)))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.shape[1],
                               x.shape[1]) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):
    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, encoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def init_mlx_weights(module, key):
    """Initialize weights for T5 model components to match PyTorch initialization"""

    def normal(key, shape, std=1.0):
        # mx.random.normal takes the shape first; the key is keyword-only here
        return mx.random.normal(shape=shape, key=key) * std

    if isinstance(module, T5LayerNorm):
        module.weight = mx.ones_like(module.weight)
    elif isinstance(module, nn.Embedding):
        key = mx.random.split(key, 1)[0]
        module.weight = normal(key, module.weight.shape, std=1.0)
    elif isinstance(module, T5FeedForward):
        # Match PyTorch initialization
        key1, key2, key3 = mx.random.split(key, 3)
        module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
                                         std=module.dim**-0.5)
        module.fc1.weight = normal(key2, module.fc1.weight.shape,
                                   std=module.dim**-0.5)
        module.fc2.weight = normal(key3, module.fc2.weight.shape,
                                   std=module.dim_ffn**-0.5)
    elif isinstance(module, T5Attention):
        # Match PyTorch initialization (was `random.split`, a NameError)
        key1, key2, key3, key4 = mx.random.split(key, 4)
        module.q.weight = normal(key1, module.q.weight.shape,
                                 std=(module.dim * module.dim_attn)**-0.5)
        module.k.weight = normal(key2, module.k.weight.shape,
                                 std=module.dim**-0.5)
        module.v.weight = normal(key3, module.v.weight.shape,
                                 std=module.dim**-0.5)
        module.o.weight = normal(key4, module.o.weight.shape,
                                 std=(module.num_heads * module.dim_attn)**-0.5)
    elif isinstance(module, T5RelativeEmbedding):
        key = mx.random.split(key, 1)[0]
        module.embedding.weight = normal(
            key, module.embedding.weight.shape,
            std=(2 * module.num_buckets * module.num_heads)**-0.5)
    elif isinstance(module, nn.Linear):
        # Generic linear layer initialization
        key = mx.random.split(key, 1)[0]
        fan_in = module.weight.shape[1]
        bound = 1.0 / math.sqrt(fan_in)
        module.weight = mx.random.uniform(low=-bound, high=bound,
                                          shape=module.weight.shape, key=key)

    return module


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        **kwargs):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    model = model_cls(**kwargs)

    # Initialize weights properly. init_mlx_weights only handles one module,
    # so walk the whole module tree with a fresh key per submodule
    key = mx.random.key(0)
    for submodule in model.modules():
        key, subkey = mx.random.split(key)
        init_mlx_weights(submodule, subkey)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.0)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)


class T5EncoderModel:
    def __init__(
        self,
        text_len,
        checkpoint_path=None,
        tokenizer_path=None,
    ):
        self.text_len = text_len
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False)

        if checkpoint_path:
            logging.info(f'loading {checkpoint_path}')
            # Load weights - assuming an MLX-format checkpoint
            weights = mx.load(checkpoint_path)
            model.update(tree_unflatten(list(weights.items())))

        self.model = model

        # init tokenizer
        from .tokenizers import HuggingfaceTokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
            seq_len=text_len,
            clean='whitespace')

    def __call__(self, texts):
        # Handle single string input
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize texts
        tokenizer_output = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)

        # Handle different tokenizer output formats
        if isinstance(tokenizer_output, tuple):
            ids, mask = tokenizer_output
        else:
            # Assuming dict output with 'input_ids' and 'attention_mask'
            ids = tokenizer_output['input_ids']
            mask = tokenizer_output['attention_mask']

        # Convert to MLX arrays if not already
        if not isinstance(ids, mx.array):
            ids = mx.array(ids)
        if not isinstance(mask, mx.array):
            mask = mx.array(mask)

        # Get sequence lengths
        seq_lens = mx.sum(mask > 0, axis=1)

        # Run encoder
        context = self.model(ids, mask)

        # Return variable-length outputs; convert seq_lens to a Python list
        # for indexing
        if seq_lens.ndim == 0:  # single value
            seq_lens_list = [seq_lens.item()]
        else:
            seq_lens_list = seq_lens.tolist()

        return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]


# Utility function to convert a PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
    """Convert PyTorch checkpoint to MLX format"""
    import torch

    # Load PyTorch checkpoint
    pytorch_state = torch.load(pytorch_path, map_location='cpu')

    # Convert to numpy then to MLX
    mlx_state = {}
    for key, value in pytorch_state.items():
        if isinstance(value, torch.Tensor):
            # Handle the key mapping if needed
            mlx_key = key
            # Convert tensor to MLX array
            mlx_state[mlx_key] = mx.array(value.numpy())

    # Save MLX checkpoint. mx.save only writes a single array, so use the
    # safetensors writer for a dict of weights (loadable via mx.load above)
    mx.save_safetensors(mlx_path, mlx_state)

    return mlx_state
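
# Illustrative usage sketch (hypothetical paths):
#
#   convert_pytorch_checkpoint(
#       'path/to/umt5-xxl-enc.pth',              # PyTorch checkpoint
#       'path/to/umt5-xxl-enc.safetensors')      # MLX-loadable output
#   encoder = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='path/to/umt5-xxl-enc.safetensors')
#   context = encoder('a cat surfing a wave')    # list of [seq_len, 4096]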
82
video/Wan2.1/wan/modules/tokenizers.py
Normal file
@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ['HuggingfaceTokenizer']


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
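
# Illustrative usage (assumes network access to the Hugging Face hub; the
# prompt is arbitrary):
#
#   tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512,
#                              clean='whitespace')
#   ids, mask = tok('A   cat  surfing!', return_mask=True)
#   # ids, mask: [1, 512] tensors, padded/truncated to seq_len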
702
video/Wan2.1/wan/modules/vae.py
Normal file
@ -0,0 +1,702 @@
# Original PyTorch implementation of Wan VAE

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

__all__ = [
    'WanVAE',
]

CACHE_T = 2

debug_line = 0


def debug(name, x):
    global debug_line
    print(f"LINE {debug_line}: {name}: shape = {tuple(x.shape)}, mean = {x.mean().item():.4f}, std = {x.std().item():.4f}")
    debug_line += 1
    return x


class CausalConv3d(nn.Conv3d):
    """
    Causal 3D convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        result = super().forward(x)
        debug("TORCH x after conv3d", result)
        return result

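
# Illustrative padding arithmetic (not upstream code): F.pad's tuple above is
# ordered last-dim-first, (W_left, W_right, H_top, H_bottom, T_front, T_back).
# For a 3x3x3 kernel with padding=1, _padding = (1, 1, 1, 1, 2, 0): the
# spatial dims are padded symmetrically, while the temporal axis gets both
# pad frames in front, so an output frame never sees future inputs.
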
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else
|
||||
-1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
result = super().forward(x.float()).type_as(x)
|
||||
debug("TORCH x after upsample", result)
|
||||
return result
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
debug("TORCH x after time_conv", x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
debug("TORCH x after time_conv with cache", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
debug("TORCH x after resample", x)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
debug("TORCH x after shortcut", h)
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
debug("TORCH x after residual block", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after residual block", x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.norm(x)
|
||||
debug("TORCH x after norm", x)
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
|
||||
-1).permute(0, 1, 3,
|
||||
2).contiguous().chunk(
|
||||
3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
debug("TORCH x after proj", x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
                # cache the last frame of each of the last two chunks
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
debug("TORCH x after conv1 with cache", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
debug("TORCH x after conv1", x)
|
||||
|
||||
## downsamples
|
||||
for i, layer in enumerate(self.downsamples):
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
                    # cache the last frame of each of the last two chunks
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
debug("TORCH x after downsample layer", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
                # cache the last frame of each of the last two chunks
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
debug("TORCH x after conv1 with cache", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
debug("TORCH x after conv1", x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after middle layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after middle layer", x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after upsample layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after upsample layer", x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
                    # cache the last frame of each of the last two chunks
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
debug("TORCH x after head layer", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after head layer", x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
    # NOTE: legacy pathway; encode() below requires a `scale` argument and
    # returns only the normalized mean, so this forward() is not usable as written.
    def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
        ## split the encoder input x along time into chunks of 1, 4, 4, 4, ...
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
|
||||
return mu
|
||||
|
||||
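    # Chunking arithmetic used by encode() above: a clip with t frames is
    # split along time into 1 + (t - 1) // 4 chunks of sizes 1, 4, 4, ...
    # For example t = 17 gives 5 chunks; the temporal downsamples turn each
    # 4-frame chunk into one latent frame, so the latent has 1 + (t - 1) // 4
    # frames, stitched together seamlessly by the causal feature cache.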
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
|
||||
"""
|
||||
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
||||
"""
|
||||
# params
|
||||
cfg = dict(
|
||||
dim=96,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
|
||||
# init model
|
||||
with torch.device('meta'):
|
||||
model = WanVAE_(**cfg)
|
||||
|
||||
# load checkpoint
|
||||
logging.info(f'loading {pretrained_path}')
|
||||
model.load_state_dict(
|
||||
torch.load(pretrained_path, map_location=device), assign=True)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class WanVAE:
|
||||
|
||||
def __init__(self,
|
||||
z_dim=16,
|
||||
vae_pth='cache/vae_step_411000.pth',
|
||||
dtype=torch.float,
|
||||
device="cpu"):
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = torch.tensor(mean, dtype=dtype, device=device)
|
||||
self.std = torch.tensor(std, dtype=dtype, device=device)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = _video_vae(
|
||||
pretrained_path=vae_pth,
|
||||
z_dim=z_dim,
|
||||
).eval().requires_grad_(False).to(device)
|
||||
|
||||
def encode(self, videos):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
"""
|
||||
with amp.autocast(dtype=self.dtype):
|
||||
return [
|
||||
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
||||
for u in videos
|
||||
]
|
||||
|
||||
def decode(self, zs):
|
||||
with amp.autocast(dtype=self.dtype):
|
||||
return [
|
||||
self.model.decode(u.unsqueeze(0),
|
||||
self.scale).float().clamp_(-1, 1).squeeze(0)
|
||||
for u in zs
|
||||
]
|
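For orientation, a minimal sketch of how this PyTorch wrapper is typically driven; the checkpoint path and shapes are illustrative assumptions, not part of the file:

    import torch
    from wan.modules.vae import WanVAE  # assumed import path

    vae = WanVAE(z_dim=16, vae_pth='cache/vae_step_411000.pth', device='cpu')
    video = torch.randn(3, 17, 480, 832)  # [C, T, H, W] with T = 1 + 4k
    z = vae.encode([video])[0]            # -> [16, 5, 60, 104]
    recon = vae.decode([z])[0]            # -> [3, 17, 480, 832], clamped to [-1, 1]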
719
video/Wan2.1/wan/modules/vae_mlx.py
Normal file
@ -0,0 +1,719 @@
|
||||
# vae_mlx.py
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
__all__ = [
|
||||
'WanVAE',
|
||||
]
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
debug_line = 0
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolution for MLX.
|
||||
Expects input in BTHWC format (batch, time, height, width, channels).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
        # Padding order: (W_left, W_right, H_top, H_bottom, T_front, T_back=0)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def __call__(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
|
||||
padding[4] -= cache_x.shape[1]
|
||||
|
||||
# Pad in BTHWC format
|
||||
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
|
||||
(padding[0], padding[1]), (0, 0)]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
result = super().__call__(x)
|
||||
return result
|
||||
|
||||
|
||||
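# Padding arithmetic for the causal conv above: a temporal kernel of 3 with
# padding=1 becomes a front-only pad of 2 frames (2 * padding[0]) and zero
# pad at the back, so no output frame ever sees the future. When a 2-frame
# cache from the previous chunk is passed in, padding[4] drops to 0 and the
# cached frames supply the causal context instead.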
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=False, images=True, bias=False):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.images = images
|
||||
self.scale = dim**0.5
|
||||
|
||||
# Just keep as 1D - let broadcasting do its magic
|
||||
self.gamma = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,)) if bias else 0.
|
||||
|
||||
def __call__(self, x):
|
||||
# F.normalize in PyTorch does L2 normalization, not RMS!
|
||||
# For NHWC/BTHWC format, normalize along the last axis
|
||||
# L2 norm: sqrt(sum(x^2))
|
||||
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
|
||||
x = x / norm
|
||||
return x * self.scale * self.gamma + self.bias
|
||||
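# Equivalence note: the PyTorch version applies F.normalize (x / ||x||_2 over
# channels) and then multiplies by scale = sqrt(dim), which is exactly
# x / rms(x); this port reproduces that, with the epsilon added under the
# square root rather than clamping the norm (a negligible difference).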
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
Upsampling layer that matches PyTorch's behavior.
|
||||
"""
|
||||
def __init__(self, scale_factor, mode='nearest-exact'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode # mode is now unused, but kept for signature consistency
|
||||
|
||||
def __call__(self, x):
|
||||
# For NHWC format (n, h, w, c)
|
||||
|
||||
        # NOTE: for an integer scale_factor like 2.0, PyTorch's
        # 'nearest-exact' reduces to a simple repeat along H and W,
        # which is what the repeats below implement.
|
||||
|
||||
scale_h, scale_w = self.scale_factor
|
||||
|
||||
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
|
||||
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
|
||||
|
||||
return out
|
||||
|
||||
class AsymmetricPad(nn.Module):
|
||||
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
|
||||
def __init__(self, pad_width: tuple):
|
||||
super().__init__()
|
||||
self.pad_width = pad_width
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.pad(x, self.pad_width)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
# Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2)
|
||||
# Use the new AsymmetricPad module.
|
||||
# Pad width for NHWC format is ((N), (H), (W), (C))
|
||||
# Pad H with (top, bottom) and W with (left, right)
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
elif mode == 'downsample3d':
|
||||
# The spatial downsampling part uses the same logic
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
        # Mirrors the PyTorch Resample.forward control flow, in BTHWC layout.
|
||||
b, t, h, w, c = x.shape
|
||||
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
mx.zeros_like(cache_x), cache_x
|
||||
], axis=1)
|
||||
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, t, h, w, 2, c)
|
||||
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
|
||||
x = x.reshape(b, t * 2, h, w, c)
|
||||
|
||||
t = x.shape[1]
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
x = self.resample(x)
|
||||
|
||||
_, h_new, w_new, c_new = x.shape
|
||||
x = x.reshape(b, t, h_new, w_new, c_new)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -1:, :, :, :]
|
||||
x = self.time_conv(
|
||||
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
return x
|
||||
|
||||
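# About the 'Rep' sentinel above: the first temporal chunk has no cache to
# draw on, so feat_cache[idx] is set to the marker 'Rep' and time_conv falls
# back to its own zero padding; later chunks replace the marker with real
# cached frames. This appears to mirror the PyTorch implementation's behavior.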
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
|
||||
for i, layer in enumerate(self.residual.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
self.proj.weight = mx.zeros_like(self.proj.weight)
|
||||
|
||||
def __call__(self, x):
|
||||
# x is in BTHWC format
|
||||
identity = x
|
||||
b, t, h, w, c = x.shape
|
||||
x = x.reshape(b * t, h, w, c) # Combine batch and time
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
|
||||
qkv = qkv.reshape(b * t, h * w, 3 * c)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Reshape for attention
|
||||
q = q.reshape(b * t, h * w, c)
|
||||
k = k.reshape(b * t, h * w, c)
|
||||
v = v.reshape(b * t, h * w, c)
|
||||
|
||||
# Scaled dot product attention
|
||||
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
|
||||
scores = (q @ k.transpose(0, 2, 1)) * scale
|
||||
weights = mx.softmax(scores, axis=-1)
|
||||
x = weights @ v
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = x.reshape(b, t, h, w, c)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[-1], dims[-1], dropout),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1], dropout)
|
||||
)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for i, layer in enumerate(self.downsamples.layers):
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout)
|
||||
)
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1)
|
||||
)
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples.layers:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x, scale):
|
||||
# x is in BTHWC format
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[1]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## Split encode input x by time into 1, 4, 4, 4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :1, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
|
||||
z = self.conv1(out)
|
||||
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
|
||||
|
||||
if isinstance(scale[0], mx.array):
|
||||
# Reshape scale for broadcasting in BTHWC format
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
mu = (mu - scale_mean) * scale_std
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
|
||||
return mu, log_var
|
||||
|
||||
def decode(self, z, scale):
|
||||
# z is in BTHWC format
|
||||
self.clear_cache()
|
||||
if isinstance(scale[0], mx.array):
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
z = z / scale_std + scale_mean
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[1]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = mx.exp(0.5 * log_var)
|
||||
eps = mx.random.normal(std.shape)
|
||||
return eps * std + mu
|
||||
|
||||
    def __call__(self, x, scale):
        # WanVAE_ has no `scale` attribute of its own (the normalization
        # stats live on the WanVAE wrapper below), so pass them explicitly.
        mu, log_var = self.encode(x, scale)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z, scale)
        return x_recon, mu, log_var
|
||||
|
||||
    def sample(self, imgs, scale, deterministic=False):
        mu, log_var = self.encode(imgs, scale)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
|
||||
return mu + std * mx.random.normal(std.shape)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
|
||||
"""
|
||||
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
||||
"""
|
||||
# params
|
||||
cfg = dict(
|
||||
dim=96,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
|
||||
# init model
|
||||
model = WanVAE_(**cfg)
|
||||
|
||||
# load checkpoint
|
||||
if pretrained_path:
|
||||
logging.info(f'loading {pretrained_path}')
|
||||
weights = mx.load(pretrained_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class WanVAE:
|
||||
|
||||
def __init__(self,
|
||||
z_dim=16,
|
||||
vae_pth='cache/vae_step_411000.pth',
|
||||
dtype=mx.float32):
|
||||
self.dtype = dtype
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = mx.array(mean, dtype=dtype)
|
||||
self.std = mx.array(std, dtype=dtype)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = _video_vae(
|
||||
pretrained_path=vae_pth,
|
||||
z_dim=z_dim,
|
||||
)
|
||||
|
||||
def encode(self, videos):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
Returns: List of encoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
encoded = []
|
||||
for video in videos:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(video, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Encode
|
||||
z = self.model.encode(x, self.scale)[0] # Get mu only
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
z = z.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
encoded.append(z.astype(mx.float32))
|
||||
|
||||
return encoded
|
||||
|
||||
def decode(self, zs):
|
||||
"""
|
||||
zs: A list of latent codes each with shape [C, T, H, W].
|
||||
Returns: List of decoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
decoded = []
|
||||
for z in zs:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(z, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Decode
|
||||
x = self.model.decode(x, self.scale)
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
x = x.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
# Clamp values
|
||||
x = mx.clip(x, -1, 1)
|
||||
|
||||
decoded.append(x.astype(mx.float32))
|
||||
|
||||
return decoded
|
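A minimal sketch of driving the MLX wrapper, mirroring the PyTorch usage; the weight path is a hypothetical pre-converted MLX checkpoint, and shapes are illustrative:

    import mlx.core as mx
    from wan.modules.vae_mlx import WanVAE  # assumed import path

    vae = WanVAE(z_dim=16, vae_pth='Wan2.1-T2V-1.3B/vae_mlx.safetensors')
    video = mx.random.normal((3, 17, 480, 832))  # [C, T, H, W]
    z = vae.encode([video])[0]                   # -> [16, 5, 60, 104]
    recon = vae.decode([z])[0]                   # -> [3, 17, 480, 832], clipped to [-1, 1]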
170
video/Wan2.1/wan/modules/xlm_roberta.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['XLMRoberta', 'xlm_roberta_large']
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, mask):
|
||||
"""
|
||||
x: [B, L, C].
|
||||
"""
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
||||
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
||||
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
||||
|
||||
# compute attention
|
||||
p = self.dropout.p if self.training else 0.0
|
||||
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
||||
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
||||
|
||||
# output
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.post_norm = post_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
||||
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
|
||||
nn.Dropout(dropout))
|
||||
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
||||
|
||||
def forward(self, x, mask):
|
||||
if self.post_norm:
|
||||
x = self.norm1(x + self.attn(x, mask))
|
||||
x = self.norm2(x + self.ffn(x))
|
||||
else:
|
||||
x = x + self.attn(self.norm1(x), mask)
|
||||
x = x + self.ffn(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class XLMRoberta(nn.Module):
|
||||
"""
|
||||
XLMRobertaModel with no pooler and no LM head.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=250002,
|
||||
max_seq_len=514,
|
||||
type_size=1,
|
||||
pad_id=1,
|
||||
dim=1024,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
post_norm=True,
|
||||
dropout=0.1,
|
||||
eps=1e-5):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.type_size = type_size
|
||||
self.pad_id = pad_id
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.post_norm = post_norm
|
||||
self.eps = eps
|
||||
|
||||
# embeddings
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
||||
self.type_embedding = nn.Embedding(type_size, dim)
|
||||
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# norm layer
|
||||
self.norm = nn.LayerNorm(dim, eps=eps)
|
||||
|
||||
def forward(self, ids):
|
||||
"""
|
||||
ids: [B, L] of torch.LongTensor.
|
||||
"""
|
||||
b, s = ids.shape
|
||||
mask = ids.ne(self.pad_id).long()
|
||||
|
||||
# embeddings
|
||||
x = self.token_embedding(ids) + \
|
||||
self.type_embedding(torch.zeros_like(ids)) + \
|
||||
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
||||
if self.post_norm:
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
# blocks
|
||||
mask = torch.where(
|
||||
mask.view(b, 1, 1, s).gt(0), 0.0,
|
||||
torch.finfo(x.dtype).min)
|
||||
for block in self.blocks:
|
||||
x = block(x, mask)
|
||||
|
||||
# output
|
||||
if not self.post_norm:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
def xlm_roberta_large(pretrained=False,
|
||||
return_tokenizer=False,
|
||||
device='cpu',
|
||||
**kwargs):
|
||||
"""
|
||||
XLMRobertaLarge adapted from Huggingface.
|
||||
"""
|
||||
# params
|
||||
cfg = dict(
|
||||
vocab_size=250002,
|
||||
max_seq_len=514,
|
||||
type_size=1,
|
||||
pad_id=1,
|
||||
dim=1024,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
post_norm=True,
|
||||
dropout=0.1,
|
||||
eps=1e-5)
|
||||
cfg.update(**kwargs)
|
||||
|
||||
# init a model on device
|
||||
with torch.device(device):
|
||||
model = XLMRoberta(**cfg)
|
||||
return model
|
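A short sketch of running the text encoder; weight loading is omitted and the token ids are dummies:

    import torch
    from wan.modules.xlm_roberta import xlm_roberta_large  # assumed import path

    model = xlm_roberta_large().eval()
    ids = torch.full((1, 8), 5, dtype=torch.long)  # any non-pad token ids
    with torch.no_grad():
        feats = model(ids)                         # -> [1, 8, 1024]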
310
video/Wan2.1/wan/t5_model_io.py
Normal file
@ -0,0 +1,310 @@
|
||||
import json
|
||||
from typing import Optional, List, Tuple
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
from safetensors import safe_open
|
||||
import torch
|
||||
|
||||
|
||||
def check_safetensors_dtypes(safetensors_path: str):
|
||||
"""
|
||||
Check what dtypes are in the safetensors file.
|
||||
Useful for debugging dtype issues.
|
||||
"""
|
||||
print(f"🔍 Checking dtypes in: {safetensors_path}")
|
||||
|
||||
dtype_counts = {}
|
||||
|
||||
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
dtype_str = str(tensor.dtype)
|
||||
|
||||
if dtype_str not in dtype_counts:
|
||||
dtype_counts[dtype_str] = []
|
||||
dtype_counts[dtype_str].append(key)
|
||||
|
||||
print("📊 Dtype summary:")
|
||||
for dtype, keys in dtype_counts.items():
|
||||
print(f" {dtype}: {len(keys)} parameters")
|
||||
if dtype == "torch.bfloat16":
|
||||
print(f" ⚠️ BFloat16 detected - will convert to float32")
|
||||
print(f" Examples: {keys[:3]}")
|
||||
|
||||
return dtype_counts
|
||||
|
||||
|
||||
def convert_tensor_dtype(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert tensor to MLX-compatible dtype.
|
||||
"""
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
# Convert BFloat16 to float32
|
||||
return tensor.float()
|
||||
elif tensor.dtype == torch.float64:
|
||||
# Convert float64 to float32 for efficiency
|
||||
return tensor.float()
|
||||
else:
|
||||
# Keep other dtypes as-is
|
||||
return tensor
|
||||
|
||||
|
||||
def map_t5_encoder_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
|
||||
"""
|
||||
Map T5 encoder weights from PyTorch format to MLX format.
|
||||
Following the pattern used in MLX Stable Diffusion.
|
||||
|
||||
Args:
|
||||
key: Parameter name from PyTorch model
|
||||
value: Parameter tensor
|
||||
|
||||
Returns:
|
||||
List of (key, value) tuples for MLX model
|
||||
"""
|
||||
|
||||
# Handle the main structural difference: FFN gate layer
|
||||
if ".ffn.gate.0.weight" in key:
|
||||
# PyTorch has Sequential(Linear, GELU) but MLX has separate gate_proj + gate_act
|
||||
key = key.replace(".ffn.gate.0.weight", ".ffn.gate_proj.weight")
|
||||
return [(key, value)]
|
||||
|
||||
elif ".ffn.gate.0.bias" in key:
|
||||
# Handle bias if it exists
|
||||
key = key.replace(".ffn.gate.0.bias", ".ffn.gate_proj.bias")
|
||||
return [(key, value)]
|
||||
|
||||
elif ".ffn.gate.1" in key:
|
||||
# Skip GELU activation parameters - MLX handles this separately
|
||||
print(f"Skipping GELU parameter: {key}")
|
||||
return []
|
||||
|
||||
# Handle any other potential FFN mappings
|
||||
elif ".ffn.fc1.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".ffn.fc2.weight" in key:
|
||||
return [(key, value)]
|
||||
|
||||
# Handle attention layers (should be direct mapping)
|
||||
elif ".attn.q.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".attn.k.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".attn.v.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".attn.o.weight" in key:
|
||||
return [(key, value)]
|
||||
|
||||
# Handle embeddings and norms (direct mapping)
|
||||
elif "token_embedding.weight" in key:
|
||||
return [(key, value)]
|
||||
elif "pos_embedding.embedding.weight" in key:
|
||||
return [(key, value)]
|
||||
elif "norm1.weight" in key or "norm2.weight" in key or "norm.weight" in key:
|
||||
return [(key, value)]
|
||||
|
||||
# Default: direct mapping for any other parameters
|
||||
else:
|
||||
return [(key, value)]
|
||||
|
||||
|
||||
def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
|
||||
"""Flatten nested list of parameter tuples"""
|
||||
return [(k, v) for p in params for (k, v) in p]
|
||||
|
||||
|
||||
def _load_safetensor_weights(
|
||||
mapper_func,
|
||||
model,
|
||||
weight_file: str,
|
||||
float16: bool = False
|
||||
):
|
||||
"""
|
||||
Load safetensor weights using the mapping function.
|
||||
Based on MLX SD pattern.
|
||||
"""
|
||||
dtype = mx.float16 if float16 else mx.float32
|
||||
|
||||
# Load weights from safetensors file
|
||||
weights = {}
|
||||
with safe_open(weight_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
# Handle BFloat16 - convert to float32 first
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
print(f"Converting BFloat16 to float32 for: {key}")
|
||||
tensor = tensor.float() # Convert to float32
|
||||
|
||||
weights[key] = mx.array(tensor.numpy()).astype(dtype)
|
||||
|
||||
# Apply mapping function
|
||||
mapped_weights = _flatten([mapper_func(k, v) for k, v in weights.items()])
|
||||
|
||||
# Update model with mapped weights
|
||||
model.update(tree_unflatten(mapped_weights))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_t5_encoder_from_safetensors(
|
||||
safetensors_path: str,
|
||||
model, # Your MLX T5Encoder instance
|
||||
float16: bool = False
|
||||
):
|
||||
"""
|
||||
Load T5 encoder weights from safetensors file into MLX model.
|
||||
|
||||
Args:
|
||||
safetensors_path: Path to the safetensors file
|
||||
model: Your MLX T5Encoder model instance
|
||||
float16: Whether to use float16 precision
|
||||
|
||||
Returns:
|
||||
Model with loaded weights
|
||||
"""
|
||||
print(f"Loading T5 encoder weights from: {safetensors_path}")
|
||||
|
||||
# Load and map weights
|
||||
model = _load_safetensor_weights(
|
||||
map_t5_encoder_weights,
|
||||
model,
|
||||
safetensors_path,
|
||||
float16
|
||||
)
|
||||
|
||||
print("T5 encoder weights loaded successfully!")
|
||||
return model
|
||||
|
||||
|
||||
def debug_weight_mapping(safetensors_path: str, float16: bool = False):
|
||||
"""
|
||||
Debug function to see how weights are being mapped.
|
||||
Useful for troubleshooting conversion issues.
|
||||
"""
|
||||
dtype = mx.float16 if float16 else mx.float32
|
||||
|
||||
print("=== T5 Weight Mapping Debug ===")
|
||||
|
||||
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
# Handle BFloat16
|
||||
original_dtype = tensor.dtype
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
print(f"Converting BFloat16 to float32 for: {key}")
|
||||
tensor = tensor.float()
|
||||
|
||||
value = mx.array(tensor.numpy()).astype(dtype)
|
||||
|
||||
# Apply mapping
|
||||
mapped = map_t5_encoder_weights(key, value)
|
||||
|
||||
if len(mapped) == 0:
|
||||
print(f"SKIPPED: {key} ({original_dtype}) -> (no mapping)")
|
||||
elif len(mapped) == 1:
|
||||
new_key, new_value = mapped[0]
|
||||
if new_key == key:
|
||||
print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]")
|
||||
else:
|
||||
print(f"MAPPED: {key} ({original_dtype}) -> {new_key} [{tensor.shape}]")
|
||||
else:
|
||||
print(f"SPLIT: {key} ({original_dtype}) -> {len(mapped)} parameters")
|
||||
for new_key, new_value in mapped:
|
||||
print(f" -> {new_key} [{new_value.shape}]")
|
||||
|
||||
|
||||
def convert_safetensors_to_mlx_weights(
|
||||
safetensors_path: str,
|
||||
output_path: str,
|
||||
float16: bool = False
|
||||
):
|
||||
"""
|
||||
    Convert a safetensors file to an MLX-compatible safetensors weights file.
|
||||
|
||||
Args:
|
||||
safetensors_path: Input safetensors file
|
||||
        output_path: Output MLX weights file (.safetensors)
|
||||
float16: Whether to use float16 precision
|
||||
"""
|
||||
dtype = mx.float16 if float16 else mx.float32
|
||||
|
||||
print(f"Converting safetensors to MLX format...")
|
||||
print(f"Input: {safetensors_path}")
|
||||
print(f"Output: {output_path}")
|
||||
print(f"Target dtype: {dtype}")
|
||||
|
||||
# Load and convert weights
|
||||
weights = {}
|
||||
bfloat16_count = 0
|
||||
|
||||
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
            # Handle BFloat16: torch.Tensor.numpy() cannot represent bfloat16,
            # so convert to float32 before handing the data to MLX.
            if tensor.dtype == torch.bfloat16:
                bfloat16_count += 1
                tensor = tensor.float()  # convert to float32 first

            value = mx.array(tensor.numpy()).astype(dtype)
|
||||
|
||||
# Apply mapping
|
||||
mapped = map_t5_encoder_weights(key, value)
|
||||
|
||||
for new_key, new_value in mapped:
|
||||
weights[new_key] = new_value
|
||||
|
||||
if bfloat16_count > 0:
|
||||
print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to float32")
|
||||
|
||||
# Save as MLX format
|
||||
print(f"Saving {len(weights)} parameters to: {output_path}")
|
||||
mx.save_safetensors(output_path, weights)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
# Example usage functions
|
||||
def example_usage():
|
||||
"""Example of how to use the converter with BFloat16 handling"""
|
||||
|
||||
safetensors_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors"
|
||||
|
||||
# Step 1: Check dtypes first
|
||||
print("=== Step 1: Check dtypes ===")
|
||||
check_safetensors_dtypes(safetensors_file)
|
||||
|
||||
# Step 2: Debug the mapping
|
||||
print("\n=== Step 2: Debug weight mapping ===")
|
||||
debug_weight_mapping(safetensors_file)
|
||||
|
||||
# Step 3: Convert to MLX weights file
|
||||
print("\n=== Step 3: Convert to MLX ===")
|
||||
weights = convert_safetensors_to_mlx_weights(
|
||||
safetensors_file,
|
||||
"Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors",
|
||||
float16=False # Use float32 to avoid precision loss from BFloat16
|
||||
)
|
||||
|
||||
# Step 4: Load into MLX model (example)
|
||||
print("\n=== Step 4: Load into MLX model ===")
|
||||
# model = T5Encoder # Your MLX model
|
||||
# model = load_t5_encoder_from_safetensors(
|
||||
# safetensors_file,
|
||||
# model,
|
||||
# float16=False
|
||||
# )
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run debug to see mappings
|
||||
# debug_weight_mapping("Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors")
|
||||
example_usage()
|
||||
|
||||
# Or convert weights
|
||||
# convert_safetensors_to_mlx_weights("your_model.safetensors", "your_model_mlx.npz")
|
||||
|
||||
print("T5 converter ready!")
|
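The mechanism this loader relies on is MLX's tree_unflatten, which turns flat dotted keys into the nested parameter tree that model.update() expects (numeric path components become list indices). A tiny self-contained illustration:

    import mlx.core as mx
    from mlx.utils import tree_unflatten

    flat = [("blocks.0.attn.q.weight", mx.zeros((4, 4))),
            ("blocks.0.attn.q.bias", mx.zeros((4,)))]
    nested = tree_unflatten(flat)
    # {'blocks': [{'attn': {'q': {'weight': array(...), 'bias': array(...)}}}]}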
278
video/Wan2.1/wan/t5_torch_to_sf.py
Normal file
@ -0,0 +1,278 @@
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from wan.modules.t5_mlx import T5Model
|
||||
|
||||
|
||||
def convert_pickle_to_safetensors(
|
||||
pickle_path: str,
|
||||
safetensors_path: str,
|
||||
model_class=None,
|
||||
model_kwargs=None,
|
||||
load_method: str = "weights_only" # Changed default to weights_only
|
||||
):
|
||||
"""Convert PyTorch pickle file to safetensors format."""
|
||||
|
||||
print(f"Loading PyTorch weights from: {pickle_path}")
|
||||
|
||||
# Try multiple loading methods in order of preference
|
||||
methods_to_try = [load_method, "weights_only", "state_dict", "full_model"]
|
||||
|
||||
for method in methods_to_try:
|
||||
try:
|
||||
if method == "weights_only":
|
||||
state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
|
||||
|
||||
elif method == "state_dict":
|
||||
checkpoint = torch.load(pickle_path, map_location='cpu')
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif isinstance(checkpoint, dict) and 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
elif method == "full_model":
|
||||
model = torch.load(pickle_path, map_location='cpu')
|
||||
if hasattr(model, 'state_dict'):
|
||||
state_dict = model.state_dict()
|
||||
else:
|
||||
state_dict = model
|
||||
|
||||
print(f"✅ Successfully loaded with method: {method}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Method {method} failed: {e}")
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"All loading methods failed for {pickle_path}")
|
||||
|
||||
# Clean up the state dict
|
||||
state_dict = clean_state_dict(state_dict)
|
||||
|
||||
print(f"Found {len(state_dict)} parameters")
|
||||
|
||||
# Convert BF16 to FP32 if needed
|
||||
for key, tensor in state_dict.items():
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
state_dict[key] = tensor.to(torch.float32)
|
||||
print(f"Converted {key} from bfloat16 to float32")
|
||||
|
||||
# Save as safetensors
|
||||
print(f"Saving to safetensors: {safetensors_path}")
|
||||
os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
|
||||
save_file(state_dict, safetensors_path)
|
||||
|
||||
print("✅ T5 conversion complete!")
|
||||
return state_dict


def clean_state_dict(state_dict):
    """
    Clean up state dict by removing unwanted prefixes or keys.
    """
    cleaned = {}

    for key, value in state_dict.items():
        # Remove common prefixes that might interfere
        clean_key = key

        if clean_key.startswith('module.'):
            clean_key = clean_key[7:]

        if clean_key.startswith('model.'):
            clean_key = clean_key[6:]

        cleaned[clean_key] = value

        if clean_key != key:
            print(f"Cleaned key: {key} -> {clean_key}")

    return cleaned


def load_with_your_torch_model(pickle_path: str, model_class, **model_kwargs):
    """
    Load pickle weights into your specific PyTorch T5 model implementation.

    Args:
        pickle_path: Path to pickle file
        model_class: Your T5Encoder class
        **model_kwargs: Arguments for your model constructor
    """
    print("Method 1: Loading into your PyTorch T5 model")

    # Initialize your model
    model = model_class(**model_kwargs)

    # Load checkpoint
    checkpoint = torch.load(pickle_path, map_location='cpu')

    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            # Assume the dict IS the state dict
            state_dict = checkpoint
    else:
        # Assume it's a model object
        state_dict = checkpoint.state_dict()

    # Clean and load
    state_dict = clean_state_dict(state_dict)
    model.load_state_dict(state_dict, strict=False)  # Use strict=False to ignore missing keys

    return model, state_dict


def explore_pickle_file(pickle_path: str):
    """
    Explore the contents of a pickle file to understand its structure.
    """
    print(f"🔍 Exploring pickle file: {pickle_path}")

    try:
        # Try loading with weights_only first (safer)
        print("\n--- Trying weights_only=True ---")
        try:
            data = torch.load(pickle_path, map_location='cpu', weights_only=True)
            print("✅ Loaded with weights_only=True")
            print(f"Type: {type(data)}")

            if isinstance(data, dict):
                print(f"Dictionary with {len(data)} keys:")
                for i, key in enumerate(data.keys()):
                    print(f"  {key}: {type(data[key])}")
                    if hasattr(data[key], 'shape'):
                        print(f"    Shape: {data[key].shape}")
                    if i >= 9:  # Show first 10 keys
                        break

        except Exception as e:
            print(f"❌ weights_only=True failed: {e}")

            # Try regular loading
            print("\n--- Trying regular loading ---")
            data = torch.load(pickle_path, map_location='cpu')
            print("✅ Loaded successfully")
            print(f"Type: {type(data)}")

            if hasattr(data, 'state_dict'):
                print("📋 Found state_dict method")
                state_dict = data.state_dict()
                print(f"State dict has {len(state_dict)} parameters")

            elif isinstance(data, dict):
                print(f"📋 Dictionary with keys: {list(data.keys())}")

                # Check for common checkpoint keys
                if 'state_dict' in data:
                    print("Found 'state_dict' key")
                    print(f"state_dict has {len(data['state_dict'])} parameters")
                elif 'model' in data:
                    print("Found 'model' key")
                    print(f"model has {len(data['model'])} parameters")

    except Exception as e:
        print(f"❌ Failed to load: {e}")


def full_conversion_pipeline(
        pickle_path: str,
        safetensors_path: str,
        torch_model_class=None,
        model_kwargs=None
):
    """
    Complete pipeline: pickle -> safetensors -> ready for MLX conversion
    """
    print("🚀 Starting full conversion pipeline")
    print("=" * 50)

    # Step 1: Explore the pickle file
    print("Step 1: Exploring pickle file structure")
    explore_pickle_file(pickle_path)

    # Step 2: Convert to safetensors
    print("\nStep 2: Converting to safetensors")

    # Try different loading methods
    for method in ["weights_only", "state_dict", "full_model"]:
        try:
            print(f"\nTrying load method: {method}")
            state_dict = convert_pickle_to_safetensors(
                pickle_path,
                safetensors_path,
                model_class=torch_model_class,
                model_kwargs=model_kwargs,
                load_method=method
            )
            print(f"✅ Success with method: {method}")
            break

        except Exception as e:
            print(f"❌ Method {method} failed: {e}")
            continue
    else:
        print("❌ All methods failed!")
        return None

    print("\n🎉 Conversion complete!")
    print(f"Safetensors file saved to: {safetensors_path}")
    print("Ready for MLX conversion using the previous script!")

    return state_dict


# Example usage
def example_usage():
    """
    Example of how to use the conversion functions
    """
    # Your model class and parameters
    # class YourT5Encoder(nn.Module):
    #     def __init__(self, vocab_size, d_model, ...):
    #         ...

    pickle_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"
    safetensors_file = "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors"

    # Method 1: Quick exploration
    print("=== Exploring pickle file ===")
    explore_pickle_file(pickle_file)

    # Method 2: Full pipeline
    print("\n=== Full conversion pipeline ===")
    state_dict = full_conversion_pipeline(
        pickle_file,
        safetensors_file,
        torch_model_class=T5Model,  # Your model class
        model_kwargs={
            'vocab_size': 256384,
            'd_model': 4096,
            'num_layers': 24,
            # ... other parameters
        }
    )

    # Method 3: Direct conversion (if you know the format)
    print("\n=== Direct conversion ===")
    # state_dict = convert_pickle_to_safetensors(
    #     pickle_file,
    #     safetensors_file,
    #     load_method="state_dict"  # or "weights_only" or "full_model"
    # )


if __name__ == "__main__":
    example_usage()
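
A quick sanity check on the converted file (a sketch, not part of the diff; it assumes the `safetensors` package is installed and the checkpoint fits in memory):

    from safetensors.torch import load_file

    converted = load_file("Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.safetensors")
    print(f"{len(converted)} tensors")
    for name, tensor in list(converted.items())[:5]:
        print(name, tuple(tensor.shape), tensor.dtype)
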
312
video/Wan2.1/wan/text2video_mlx.py
Normal file
@ -0,0 +1,312 @@
import glob
import gc
import logging
import math
import os
import random
import sys

from tqdm import tqdm

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .modules.model_mlx import WanModel
from .modules.t5_mlx import T5EncoderModel
from .modules.vae_mlx import WanVAE
from .utils.fm_solvers_mlx import (FlowDPMSolverMultistepScheduler,
                                   get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
from .wan_model_io import load_wan_from_safetensors


class WanT2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
    ):
        r"""
        Initializes the Wan text-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
        """
        self.config = config
        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = mx.float16 if config.param_dtype == 'float16' else mx.float32

        # Initialize T5 text encoder - with automatic conversion
        t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint)
        mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors')
        if not os.path.exists(mlx_t5_path):
            # Check if it's a .pth file that needs conversion
            pth_path = t5_checkpoint_path.replace('.safetensors', '.pth')
            if os.path.exists(pth_path):
                logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}")
                from .t5_torch_to_sf import convert_pickle_to_safetensors
                convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only")
                # Convert torch safetensors to MLX safetensors
                from .t5_model_io import convert_safetensors_to_mlx_weights
                convert_safetensors_to_mlx_weights(
                    t5_checkpoint_path, mlx_t5_path,
                    float16=(self.param_dtype == mx.float16))
            else:
                raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}")

        t5_checkpoint_path = mlx_t5_path  # Use the MLX version
        logging.info(f"Loading T5 text encoder from {t5_checkpoint_path}")
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            checkpoint_path=t5_checkpoint_path,
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer))

        # Initialize VAE
        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size

        # Initialize VAE - with automatic conversion
        vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint)
        if not os.path.exists(vae_path):
            # Check for PyTorch VAE file to convert
            pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth')
            if not os.path.exists(pth_vae_path):
                # Try alternative naming
                pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth')

            if os.path.exists(pth_vae_path):
                logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}")
                from .vae_model_io import convert_pytorch_to_mlx
                convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16))
            else:
                raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}")

        logging.info("Loading VAE...")
        self.vae = WanVAE(vae_pth=vae_path)

        # Initialize WanModel
        logging.info(f"Creating WanModel from {checkpoint_dir}")

        # Create model with config parameters
        self.model = WanModel(
            model_type='t2v',
            patch_size=config.patch_size,
            text_len=config.text_len,
            in_dim=16,
            dim=config.dim,
            ffn_dim=config.ffn_dim,
            freq_dim=config.freq_dim,
            text_dim=4096,
            out_dim=16,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            window_size=getattr(config, 'window_size', (-1, -1)),
            qk_norm=getattr(config, 'qk_norm', True),
            cross_attn_norm=getattr(config, 'cross_attn_norm', True),
            eps=getattr(config, 'eps', 1e-6)
        )

        # Load pretrained weights - with automatic conversion
        model_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model_mlx.safetensors")
        if not os.path.exists(model_path):
            # Check for directory with multiple files (14B model)
            pattern = os.path.join(checkpoint_dir, "diffusion_mlx_model*.safetensors")
            mlx_files = glob.glob(pattern)

            if not mlx_files:
                # No MLX files found, look for PyTorch files to convert
                pytorch_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors")
                pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
                pytorch_files = glob.glob(pytorch_pattern)

                if os.path.exists(pytorch_path):
                    logging.info(f"Converting single PyTorch model to MLX: {pytorch_path}")
                    from .wan_model_io import convert_safetensors_to_mlx_weights
                    convert_safetensors_to_mlx_weights(
                        pytorch_path,
                        model_path,
                        float16=(self.param_dtype == mx.float16)
                    )
                elif pytorch_files:
                    logging.info(f"Converting {len(pytorch_files)} PyTorch model files to MLX")
                    from .wan_model_io import convert_multiple_safetensors_to_mlx
                    convert_multiple_safetensors_to_mlx(
                        checkpoint_dir,
                        float16=(self.param_dtype == mx.float16)
                    )
                else:
                    raise FileNotFoundError(f"No PyTorch model files found in {checkpoint_dir}")

        # Load the model (now MLX format exists)
        if os.path.exists(model_path):
            # Single file (1.3B)
            logging.info(f"Loading model weights from {model_path}")
            self.model = load_wan_from_safetensors(model_path, self.model, float16=(self.param_dtype == mx.float16))
        else:
            # Multiple files (14B)
            logging.info(f"Loading model weights from directory {checkpoint_dir}")
            self.model = load_wan_from_safetensors(checkpoint_dir, self.model, float16=(self.param_dtype == mx.float16))

        # Set model to eval mode
        self.model.eval()

        self.sp_size = 1  # No sequence parallelism in MLX version
        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 size=(1280, 720),
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from a text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation
            size (tuple[`int`], *optional*, defaults to (1280, 720)):
                Controls video resolution, (width, height).
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 50):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use a random seed.
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save memory

        Returns:
            mx.array:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames
                - H: Frame height (from size)
                - W: Frame width (from size)
        """
        # Preprocess
        F = frame_num
        target_shape = (
            self.vae.model.z_dim,
            (F - 1) // self.vae_stride[0] + 1,
            size[1] // self.vae_stride[1],
            size[0] // self.vae_stride[2]
        )

        seq_len = math.ceil(
            (target_shape[2] * target_shape[3]) /
            (self.patch_size[1] * self.patch_size[2]) *
            target_shape[1] / self.sp_size
        ) * self.sp_size

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # Set random seed
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        mx.random.seed(seed)

        # Encode text prompts
        logging.info("Encoding text prompts...")
        context = self.text_encoder([input_prompt])
        context_null = self.text_encoder([n_prompt])

        # Generate initial noise
        noise = [
            mx.random.normal(
                shape=target_shape,
                dtype=mx.float32
            )
        ]

        # Initialize scheduler
        if sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False
            )
            sample_scheduler.set_timesteps(
                sampling_steps, shift=shift
            )
            timesteps = sample_scheduler.timesteps
        elif sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False
            )
            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
            timesteps, _ = retrieve_timesteps(
                sample_scheduler,
                sigmas=sampling_sigmas
            )
        else:
            raise NotImplementedError(f"Unsupported solver: {sample_solver}")

        # Sample videos
        latents = noise

        arg_c = {'context': context, 'seq_len': seq_len}
        arg_null = {'context': context_null, 'seq_len': seq_len}

        logging.info(f"Generating video with {len(timesteps)} steps...")

        for t in tqdm(timesteps):
            latent_model_input = latents
            timestep = mx.array([t])

            # Model predictions
            noise_pred_cond = self.model(
                latent_model_input, t=timestep, **arg_c
            )[0]
            noise_pred_uncond = self.model(
                latent_model_input, t=timestep, **arg_null
            )[0]

            # Classifier-free guidance
            noise_pred = noise_pred_uncond + guide_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # Scheduler step
            temp_x0 = sample_scheduler.step(
                mx.expand_dims(noise_pred, 0),
                t,
                mx.expand_dims(latents[0], 0),
                return_dict=False
            )[0]
            latents = [mx.squeeze(temp_x0, 0)]
            mx.eval(latents)

        x0 = latents

        # Decode latents to video
        logging.info("Decoding latents to video...")

        videos = self.vae.decode(x0)

        # Memory cleanup
        del noise, latents, sample_scheduler
        if offload_model:
            mx.eval(videos)  # Ensure computation is complete
            gc.collect()

        return videos[0]
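
A minimal driver for this class (a sketch, not part of the diff: `t2v_1_3B` is assumed to be one of the EasyDict config objects the constructor docstring refers to, and the checkpoint directory is assumed to already be downloaded):

    from wan.configs import t2v_1_3B  # assumed config location
    from wan.text2video_mlx import WanT2V

    pipe = WanT2V(config=t2v_1_3B, checkpoint_dir="Wan2.1-T2V-1.3B")
    video = pipe.generate(
        "A cat surfing a wave at sunset",
        size=(832, 480),   # (width, height), divisible by the VAE stride
        frame_num=33,      # 4n+1 frames, per the docstring
        sampling_steps=30,
        seed=42,
    )
    print(video.shape)  # (3, 33, 480, 832): (C, N, H, W)
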
8
video/Wan2.1/wan/utils/__init__.py
Normal file
@ -0,0 +1,8 @@
from .fm_solvers_mlx import (FlowDPMSolverMultistepScheduler,
                             get_sampling_sigmas, retrieve_timesteps)
from .fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler

__all__ = [
    'get_sampling_sigmas', 'retrieve_timesteps',
    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
]
562
video/Wan2.1/wan/utils/fm_solvers_mlx.py
Normal file
@ -0,0 +1,562 @@
import math
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import numpy as np


def get_sampling_sigmas(sampling_steps, shift):
    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    sigma = (shift * sigma / (1 + (shift - 1) * sigma))
    return sigma


def retrieve_timesteps(
    scheduler,
    num_inference_steps=None,
    device=None,
    timesteps=None,
    sigmas=None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
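
# Worked example (illustrative): with shift=5.0 and sampling_steps=4,
# get_sampling_sigmas first builds the uniform grid [1.0, 0.75, 0.5, 0.25],
# then applies s*sigma / (1 + (s-1)*sigma), giving roughly
# [1.0, 0.9375, 0.8333, 0.625] -- larger shifts keep more of the step
# budget at high noise levels.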


class SchedulerOutput:
    """Output class for scheduler step results."""

    def __init__(self, prev_sample: mx.array):
        self.prev_sample = prev_sample


class FlowDPMSolverMultistepScheduler:
    """
    MLX implementation of FlowDPMSolverMultistepScheduler.
    A fast dedicated high-order solver for diffusion ODEs.
    """

    order = 1

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting: bool = False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        final_sigmas_type: Optional[str] = "zero",
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        invert_sigmas: bool = False,
    ):
        # Store configuration
        self.config = {
            'num_train_timesteps': num_train_timesteps,
            'solver_order': solver_order,
            'prediction_type': prediction_type,
            'shift': shift,
            'use_dynamic_shifting': use_dynamic_shifting,
            'thresholding': thresholding,
            'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
            'sample_max_value': sample_max_value,
            'algorithm_type': algorithm_type,
            'solver_type': solver_type,
            'lower_order_final': lower_order_final,
            'euler_at_final': euler_at_final,
            'final_sigmas_type': final_sigmas_type,
            'lambda_min_clipped': lambda_min_clipped,
            'variance_type': variance_type,
            'invert_sigmas': invert_sigmas,
        }

        # Validate algorithm type
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.config['algorithm_type'] = "dpmsolver++"
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented")

        # Validate solver type
        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.config['solver_type'] = "midpoint"
            else:
                raise NotImplementedError(f"{solver_type} is not implemented")

        # Initialize scheduling
        self.num_inference_steps = None
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = mx.array(sigmas, dtype=mx.float32)

        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None

        self.sigma_min = float(self.sigmas[-1])
        self.sigma_max = float(self.sigmas[0])

    @property
    def step_index(self):
        return self._step_index

    @property
    def begin_index(self):
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, None] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """Sets the discrete timesteps used for the diffusion chain."""
        if self.config['use_dynamic_shifting'] and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
            )

        if sigmas is None:
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]

        if self.config['use_dynamic_shifting']:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            if shift is None:
                shift = self.config['shift']
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        if self.config['final_sigmas_type'] == "sigma_min":
            sigma_last = self.sigma_min
        elif self.config['final_sigmas_type'] == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
            )

        timesteps = sigmas * self.config['num_train_timesteps']
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = mx.array(sigmas)
        self.timesteps = mx.array(timesteps, dtype=mx.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [None] * self.config['solver_order']
        self.lower_order_nums = 0

        self._step_index = None
        self._begin_index = None

    def _threshold_sample(self, sample: mx.array) -> mx.array:
        """Dynamic thresholding method."""
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        # Flatten sample for quantile calculation
        sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = mx.abs(sample_flat)

        # Compute quantile
        s = mx.quantile(
            abs_sample,
            self.config['dynamic_thresholding_ratio'],
            axis=1,
            keepdims=True
        )
        s = mx.clip(s, 1, self.config['sample_max_value'])

        # Threshold and normalize
        sample_flat = mx.clip(sample_flat, -s, s) / s

        sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
        return sample.astype(dtype)

    def _sigma_to_t(self, sigma):
        return sigma * self.config['num_train_timesteps']

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    def time_shift(self, mu: float, sigma: float, t: mx.array):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)

    def convert_model_output(
        self,
        model_output: mx.array,
        sample: mx.array,
        **kwargs,
    ) -> mx.array:
        """Convert model output to the corresponding type the algorithm needs."""
        # DPM-Solver++ needs to solve an integral of the data prediction model
        if self.config['algorithm_type'] in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be "
                    f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                )

            if self.config['thresholding']:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model
        elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be "
                    f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
                )

            if self.config['thresholding']:
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: mx.array,
        sample: mx.array,
        noise: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """One step for the first-order DPMSolver (equivalent to DDIM)."""
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s = mx.log(alpha_s) - mx.log(sigma_s)

        h = lambda_t - lambda_s

        if self.config['algorithm_type'] == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
        elif self.config['algorithm_type'] == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mx.exp(h) - 1.0)) * model_output
        elif self.config['algorithm_type'] == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * mx.exp(-h)) * sample +
                (alpha_t * (1 - mx.exp(-2.0 * h))) * model_output +
                sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
            )
        elif self.config['algorithm_type'] == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample -
                2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
                sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
            )

        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[mx.array],
        sample: mx.array,
        noise: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """One step for the second-order multistep DPMSolver."""
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
        lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        if self.config['algorithm_type'] == "dpmsolver++":
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample -
                    (alpha_t * (mx.exp(-h) - 1.0)) * D0 -
                    0.5 * (alpha_t * (mx.exp(-h) - 1.0)) * D1
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample -
                    (alpha_t * (mx.exp(-h) - 1.0)) * D0 +
                    (alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config['algorithm_type'] == "dpmsolver":
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    0.5 * (sigma_t * (mx.exp(h) - 1.0)) * D1
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config['algorithm_type'] == "sde-dpmsolver++":
            assert noise is not None
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * mx.exp(-h)) * sample +
                    (alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
                    0.5 * (alpha_t * (1 - mx.exp(-2.0 * h))) * D1 +
                    sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * mx.exp(-h)) * sample +
                    (alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
                    (alpha_t * ((1.0 - mx.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 +
                    sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
                )
        elif self.config['algorithm_type'] == "sde-dpmsolver":
            assert noise is not None
            if self.config['solver_type'] == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    (sigma_t * (mx.exp(h) - 1.0)) * D1 +
                    sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                )
            elif self.config['solver_type'] == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample -
                    2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                    2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
                    sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
                )

        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[mx.array],
        sample: mx.array,
        **kwargs,
    ) -> mx.array:
        """One step for the third-order multistep DPMSolver."""
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
        lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
        lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

        if self.config['algorithm_type'] == "dpmsolver++":
            x_t = (
                (sigma_t / sigma_s0) * sample -
                (alpha_t * (mx.exp(-h) - 1.0)) * D0 +
                (alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1 -
                (alpha_t * ((mx.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config['algorithm_type'] == "dpmsolver":
            x_t = (
                (alpha_t / alpha_s0) * sample -
                (sigma_t * (mx.exp(h) - 1.0)) * D0 -
                (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 -
                (sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )

        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        # MLX has no single-argument (nonzero) form of mx.where, so find the
        # matching indices via NumPy; prefer the second match when duplicates
        # exist, mirroring the reference PyTorch implementation.
        indices = np.nonzero(np.array(schedule_timesteps == timestep))[0]
        pos = 1 if len(indices) > 1 else 0

        return int(indices[pos])

    def _init_step_index(self, timestep):
        """Initialize the step_index counter for the scheduler."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: mx.array,
        timestep: Union[int, mx.array],
        sample: mx.array,
        generator=None,
        variance_noise: Optional[mx.array] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Predict the sample from the previous timestep."""
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small numbers of steps
        lower_order_final = (
            (self.step_index == len(self.timesteps) - 1) and
            (self.config['euler_at_final'] or
             (self.config['lower_order_final'] and len(self.timesteps) < 15) or
             self.config['final_sigmas_type'] == "zero")
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and
            self.config['lower_order_final'] and
            len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config['solver_order'] - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues
        sample = sample.astype(mx.float32)

        # Generate noise if needed for SDE variants
        if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = mx.random.normal(model_output.shape, dtype=mx.float32)
        elif self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.astype(mx.float32)
        else:
            noise = None

        if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(
                model_output, sample=sample, noise=noise
            )
        elif self.config['solver_order'] == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(
                self.model_outputs, sample=sample, noise=noise
            )
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(
                self.model_outputs, sample=sample
            )

        if self.lower_order_nums < self.config['solver_order']:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.astype(model_output.dtype)

        # Increase step index
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
        """Scale model input - no scaling needed for this scheduler."""
        return sample

    def add_noise(
        self,
        original_samples: mx.array,
        noise: mx.array,
        timesteps: mx.array,
    ) -> mx.array:
        """Add noise to original samples."""
        sigmas = self.sigmas.astype(original_samples.dtype)
        schedule_timesteps = self.timesteps

        # Get step indices
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices]
        while len(sigma.shape) < len(original_samples.shape):
            sigma = mx.expand_dims(sigma, -1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        return self.config['num_train_timesteps']
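
A minimal sketch of exercising this scheduler on its own, with a zero "model" standing in for WanModel (toy latent shape, illustrative only; the real call site is text2video_mlx.py above):

    import mlx.core as mx

    from wan.utils.fm_solvers_mlx import (FlowDPMSolverMultistepScheduler,
                                          get_sampling_sigmas, retrieve_timesteps)

    scheduler = FlowDPMSolverMultistepScheduler(
        num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
    sampling_sigmas = get_sampling_sigmas(20, 5.0)   # 20 steps, shift 5.0
    timesteps, _ = retrieve_timesteps(scheduler, sigmas=sampling_sigmas)

    x = mx.random.normal(shape=(1, 16, 1, 8, 8))     # toy latent
    for t in timesteps:
        noise_pred = mx.zeros_like(x)                # stand-in for the model
        x = scheduler.step(noise_pred, t, x, return_dict=False)[0]
        mx.eval(x)
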
546
video/Wan2.1/wan/utils/fm_solvers_unipc_mlx.py
Normal file
@ -0,0 +1,546 @@
import math
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import numpy as np


class SchedulerOutput:
    """Output class for scheduler step results."""

    def __init__(self, prev_sample: mx.array):
        self.prev_sample = prev_sample


class FlowUniPCMultistepScheduler:
    """
    MLX implementation of UniPCMultistepScheduler.
    A training-free framework designed for the fast sampling of diffusion models.
    """

    order = 1

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting: bool = False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p=None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",
    ):
        # Store configuration
        self.config = {
            'num_train_timesteps': num_train_timesteps,
            'solver_order': solver_order,
            'prediction_type': prediction_type,
            'shift': shift,
            'use_dynamic_shifting': use_dynamic_shifting,
            'thresholding': thresholding,
            'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
            'sample_max_value': sample_max_value,
            'predict_x0': predict_x0,
            'solver_type': solver_type,
            'lower_order_final': lower_order_final,
            'disable_corrector': disable_corrector,
            'solver_p': solver_p,
            'timestep_spacing': timestep_spacing,
            'steps_offset': steps_offset,
            'final_sigmas_type': final_sigmas_type,
        }

        # Validate solver type
        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.config['solver_type'] = "bh2"
            else:
                raise NotImplementedError(
                    f"{solver_type} is not implemented for {self.__class__}"
                )

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = mx.array(sigmas, dtype=mx.float32)

        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None

        self.sigma_min = float(self.sigmas[-1])
        self.sigma_max = float(self.sigmas[0])

    @property
    def step_index(self):
        """The index counter for the current timestep."""
        return self._step_index

    @property
    def begin_index(self):
        """The index for the first timestep."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """Sets the begin index for the scheduler."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, None] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """Sets the discrete timesteps used for the diffusion chain."""
        if self.config['use_dynamic_shifting'] and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
            )

        if sigmas is None:
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]

        if self.config['use_dynamic_shifting']:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            if shift is None:
                shift = self.config['shift']
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        if self.config['final_sigmas_type'] == "sigma_min":
            sigma_last = self.sigma_min
        elif self.config['final_sigmas_type'] == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
            )

        timesteps = sigmas * self.config['num_train_timesteps']
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = mx.array(sigmas)
        self.timesteps = mx.array(timesteps, dtype=mx.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [None] * self.config['solver_order']
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        # add an index counter for schedulers
        self._step_index = None
        self._begin_index = None

    def _threshold_sample(self, sample: mx.array) -> mx.array:
        """Dynamic thresholding method."""
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        # Flatten sample for quantile calculation
        sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = mx.abs(sample_flat)

        # Compute quantile
        s = mx.quantile(
            abs_sample,
            self.config['dynamic_thresholding_ratio'],
            axis=1,
            keepdims=True
        )
        s = mx.clip(s, 1, self.config['sample_max_value'])

        # Threshold and normalize
        sample_flat = mx.clip(sample_flat, -s, s) / s

        sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
        return sample.astype(dtype)

    def _sigma_to_t(self, sigma):
        return sigma * self.config['num_train_timesteps']

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    def time_shift(self, mu: float, sigma: float, t: mx.array):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)

    def convert_model_output(
        self,
        model_output: mx.array,
        sample: mx.array = None,
        **kwargs,
    ) -> mx.array:
        """Convert the model output to the corresponding type the UniPC algorithm needs."""
        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
                    f"for the UniPCMultistepScheduler."
                )

            if self.config['thresholding']:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config['prediction_type'] == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
                    f"for the UniPCMultistepScheduler."
                )

            if self.config['thresholding']:
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon

    def multistep_uni_p_bh_update(
        self,
        model_output: mx.array,
        sample: mx.array = None,
        order: int = None,
        **kwargs,
    ) -> mx.array:
        """One step for the UniP (B(h) version)."""
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)

        h = lambda_t - lambda_s0

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = mx.array(rks)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = mx.exp(hh) - 1  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config['solver_type'] == "bh1":
            B_h = hh
        elif self.config['solver_type'] == "bh2":
            B_h = mx.exp(hh) - 1
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(mx.power(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = mx.stack(R)
        b = mx.array(b)

        if len(D1s) > 0:
            D1s = mx.stack(D1s, axis=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = mx.array([0.5], dtype=x.dtype)
            else:
                rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1], stream=mx.cpu).astype(x.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.astype(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: mx.array,
        last_sample: mx.array = None,
        this_sample: mx.array = None,
        order: int = None,
        **kwargs,
    ) -> mx.array:
        """One step for the UniC (B(h) version)."""
        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
        lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)

        h = lambda_t - lambda_s0

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = mx.array(rks)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = mx.exp(hh) - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config['solver_type'] == "bh1":
            B_h = hh
        elif self.config['solver_type'] == "bh2":
            B_h = mx.exp(hh) - 1
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(mx.power(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = mx.stack(R)
        b = mx.array(b)

        if len(D1s) > 0:
            D1s = mx.stack(D1s, axis=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = mx.array([0.5], dtype=x.dtype)
        else:
            rhos_c = mx.linalg.solve(R, b, stream=mx.cpu).astype(x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)

        x_t = x_t.astype(x.dtype)
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        condition = schedule_timesteps == timestep
        indices = mx.argmax(condition.astype(mx.int32))

        # Convert scalar to int and return
        return int(indices)

    def _init_step_index(self, timestep):
        """Initialize the step_index counter for the scheduler."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: mx.array,
        timestep: Union[int, mx.array],
        sample: mx.array,
        return_dict: bool = True,
        generator=None
    ) -> Union[SchedulerOutput, Tuple]:
        """Predict the sample from the previous timestep."""
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        use_corrector = (
            self.step_index > 0 and
            self.step_index - 1 not in self.disable_corrector and
            self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(
            model_output, sample=sample
        )
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        for i in range(self.config['solver_order'] - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.config['lower_order_final']:
            this_order = min(
                self.config['solver_order'],
                len(self.timesteps) - self.step_index
            )
        else:
            this_order = self.config['solver_order']

        self.this_order = min(this_order, self.lower_order_nums + 1)
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config['solver_order']:
            self.lower_order_nums += 1

        # Increase step index
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
        """Scale model input - no scaling needed for this scheduler."""
        return sample

    def add_noise(
        self,
        original_samples: mx.array,
        noise: mx.array,
        timesteps: mx.array,
    ) -> mx.array:
        """Add noise to original samples."""
        sigmas = self.sigmas.astype(original_samples.dtype)
        schedule_timesteps = self.timesteps

        # Get step indices
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices]
        while len(sigma.shape) < len(original_samples.shape):
            sigma = mx.expand_dims(sigma, -1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        return self.config['num_train_timesteps']
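
The UniPC scheduler is configured the same way generate() does it; a minimal sketch:

    from wan.utils.fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler

    scheduler = FlowUniPCMultistepScheduler(
        num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
    scheduler.set_timesteps(50, shift=5.0)  # the shift re-warps the sigma grid here
    print(len(scheduler.timesteps))         # 50 denoising steps
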
373
video/Wan2.1/wan/utils/qwen_vl_utils.py
Normal file
@ -0,0 +1,373 @@
# Copied from https://github.com/kq-chen/qwen-vl-utils
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from __future__ import annotations

import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO

import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode

logger = logging.getLogger(__name__)

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(height: int,
                 width: int,
                 factor: int = IMAGE_FACTOR,
                 min_pixels: int = MIN_PIXELS,
                 max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
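
# Worked example (illustrative): smart_resize(1000, 500) with the defaults
# rounds each side to the nearest multiple of 28 -> (1008, 504); the product
# 508,032 already lies inside [MIN_PIXELS, MAX_PIXELS], so no area rescaling
# kicks in and the aspect ratio is preserved almost exactly.
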
def fetch_image(ele: dict[str, str | Image.Image],
|
||||
size_factor: int = IMAGE_FACTOR) -> Image.Image:
|
||||
if "image" in ele:
|
||||
image = ele["image"]
|
||||
else:
|
||||
image = ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
image_obj = Image.open(requests.get(image, stream=True).raw)
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = image_obj.convert("RGB")
|
||||
## resize
|
||||
if "resized_height" in ele and "resized_width" in ele:
|
||||
resized_height, resized_width = smart_resize(
|
||||
ele["resized_height"],
|
||||
ele["resized_width"],
|
||||
factor=size_factor,
|
||||
)
|
||||
else:
|
||||
width, height = image.size
|
||||
min_pixels = ele.get("min_pixels", MIN_PIXELS)
|
||||
max_pixels = ele.get("max_pixels", MAX_PIXELS)
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=size_factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
image = image.resize((resized_width, resized_height))
|
||||
|
||||
return image
|
||||
|
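fetch_image accepts a dict keyed by image or image_url; the value may be a local path, a file://, http(s) or data:image URL, or a PIL image. A minimal sketch (the file path is a placeholder):

img = fetch_image({"image": "file:///tmp/example.jpg"})     # hypothetical local file
img = fetch_image({"image": Image.new("RGB", (640, 480))})  # or an in-memory PIL image
print(img.size)  # both sides come back as multiples of IMAGE_FACTOR (28)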

def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Calculate the number of frames to use for model inputs.

    Args:
        ele (dict): a dict containing the configuration of the video.
            Supports either `fps` or `nframes`:
            - nframes: the number of frames to extract for model inputs.
            - fps: the fps at which to extract frames for model inputs.
            - min_frames: the minimum number of frames of the video, only used when fps is provided.
            - max_frames: the maximum number of frames of the video, only used when fps is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.

    Raises:
        ValueError: if nframes falls outside the interval [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of frames to use for model inputs.
    """
    assert not ("fps" in ele and
                "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        fps = ele.get("fps", FPS)
        min_frames = ceil_by_factor(
            ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        max_frames = floor_by_factor(
            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
            FRAME_FACTOR)
        nframes = total_frames / video_fps * fps
        nframes = min(max(nframes, min_frames), max_frames)
        nframes = round_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes <= total_frames):
        raise ValueError(
            f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
        )
    return nframes
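A worked example of the fps path (illustrative numbers): a 300-frame clip at 30 fps sampled at the default FPS = 2.0 gives 300 / 30 * 2.0 = 20 frames, already a multiple of FRAME_FACTOR and inside [4, 300]:

n = smart_nframes({"fps": 2.0}, total_frames=300, video_fps=30)
print(n)  # 20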

def _read_video_torchvision(ele: dict) -> torch.Tensor:
    """Read video using torchvision.io.read_video.

    Args:
        ele (dict): a dict containing the configuration of the video.
            Supported keys:
            - video: the path of the video. Supports "file://", "http://", "https://" and local paths.
            - video_start: the start time of the video.
            - video_end: the end time of the video.
    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W).
    """
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn(
                "torchvision < 0.19.0 does not support http/https video paths, please upgrade to 0.19.0."
            )
        if "file://" in video_path:
            video_path = video_path[7:]
    st = time.time()
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(
        f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
    )
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    video = video[idx]
    return video

def is_decord_available() -> bool:
    import importlib.util

    return importlib.util.find_spec("decord") is not None


def _read_video_decord(ele: dict) -> torch.Tensor:
    """Read video using decord.VideoReader.

    Args:
        ele (dict): a dict containing the configuration of the video.
            Supported keys:
            - video: the path of the video. Supports "file://", "http://", "https://" and local paths.
            - video_start: the start time of the video.
            - video_end: the end time of the video.
    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W).
    """
    import decord
    video_path = ele["video"]
    st = time.time()
    vr = decord.VideoReader(video_path)
    # TODO: support start_pts and end_pts
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError(
            "start_pts and end_pts are not supported in decord for now.")
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(
        f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
    )
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    return video


VIDEO_READER_BACKENDS = {
    "decord": _read_video_decord,
    "torchvision": _read_video_torchvision,
}

FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
    elif is_decord_available():
        video_reader_backend = "decord"
    else:
        video_reader_backend = "torchvision"
    print(
        f"qwen-vl-utils using {video_reader_backend} to read video.",
        file=sys.stderr)
    return video_reader_backend
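The backend is probed once and cached by lru_cache: decord when installed, torchvision otherwise. Because FORCE_QWENVL_VIDEO_READER is read at module import time, an override must be set before the import; a small sketch (the import path assumes this repo's layout):

import os
os.environ["FORCE_QWENVL_VIDEO_READER"] = "torchvision"  # must precede the module import

from wan.utils.qwen_vl_utils import get_video_reader_backend  # assumed import path
print(get_video_reader_backend())  # "torchvision"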

def fetch_video(
        ele: dict,
        image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
    # Handle MPS device compatibility: decode/resize on CPU, then move back.
    original_device = None
    if isinstance(ele.get("video"), torch.Tensor) and ele["video"].device.type != "cpu":
        original_device = ele["video"].device
        ele["video"] = ele["video"].cpu()

    if isinstance(ele["video"], str):
        video_reader_backend = get_video_reader_backend()
        video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
        nframes, _, height, width = video.shape

        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
        max_pixels = max(
            min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
            int(min_pixels * 1.05))
        max_pixels = ele.get("max_pixels", max_pixels)
        if "resized_height" in ele and "resized_width" in ele:
            resized_height, resized_width = smart_resize(
                ele["resized_height"],
                ele["resized_width"],
                factor=image_factor,
            )
        else:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=image_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
        video = transforms.functional.resize(
            video,
            [resized_height, resized_width],
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ).float()
        # Return to the original device if needed (must happen before returning;
        # the original placed this after the returns, where it was unreachable).
        if original_device is not None:
            video = video.to(original_device)
        return video
    else:
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [
            fetch_image({
                "image": video_element,
                **process_info
            },
                        size_factor=image_factor)
            for video_element in ele["video"]
        ]
        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))
        return images

def extract_vision_info(
        conversations: list[dict] | list[list[dict]]) -> list[dict]:
    vision_infos = []
    if isinstance(conversations[0], dict):
        conversations = [conversations]
    for conversation in conversations:
        for message in conversation:
            if isinstance(message["content"], list):
                for ele in message["content"]:
                    if ("image" in ele or "image_url" in ele or
                            "video" in ele or
                            ele["type"] in ("image", "image_url", "video")):
                        vision_infos.append(ele)
    return vision_infos

def process_vision_info(
    conversations: list[dict] | list[list[dict]],
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
           None]:
    vision_infos = extract_vision_info(conversations)
    # Read images or videos
    image_inputs = []
    video_inputs = []
    for vision_info in vision_infos:
        if "image" in vision_info or "image_url" in vision_info:
            image_inputs.append(fetch_image(vision_info))
        elif "video" in vision_info:
            video_inputs.append(fetch_video(vision_info))
        else:
            raise ValueError("image, image_url or video should be in content.")
    if len(image_inputs) == 0:
        image_inputs = None
    if len(video_inputs) == 0:
        video_inputs = None
    return image_inputs, video_inputs
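process_vision_info walks chat-style messages and materializes every referenced image and video. A minimal sketch of the expected input shape (the clip path is a placeholder):

conversation = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "file:///tmp/clip.mp4", "fps": 2.0},  # hypothetical clip
        {"type": "text", "text": "Describe this video."},
    ],
}]
image_inputs, video_inputs = process_vision_info(conversation)
# image_inputs is None here; video_inputs holds one (T, C, H, W) tensor.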
175
video/Wan2.1/wan/utils/utils_mlx_own.py
Normal file
@ -0,0 +1,175 @@
import argparse
import binascii
import os
import os.path as osp

import imageio
import mlx.core as mx
import numpy as np

__all__ = ['cache_video', 'cache_image', 'str2bool']


def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name

def make_grid(tensor, nrow=8, normalize=True, value_range=(-1, 1)):
    """MLX equivalent of torchvision.utils.make_grid"""
    # tensor shape: (batch, channels, height, width)
    batch_size, channels, height, width = tensor.shape

    # Calculate grid dimensions
    ncol = nrow
    nrow_actual = (batch_size + ncol - 1) // ncol

    # Create grid with a 2-pixel gutter between cells
    grid_height = height * nrow_actual + (nrow_actual - 1) * 2
    grid_width = width * ncol + (ncol - 1) * 2

    # Initialize grid with zeros
    grid = mx.zeros((channels, grid_height, grid_width))

    # Fill grid
    for idx in range(batch_size):
        row = idx // ncol
        col = idx % ncol

        y_start = row * (height + 2)
        y_end = y_start + height
        x_start = col * (width + 2)
        x_end = x_start + width

        img = tensor[idx]
        if normalize:
            # Normalize to [0, 1]
            img = (img - value_range[0]) / (value_range[1] - value_range[0])

        grid[:, y_start:y_end, x_start:x_end] = img

    return grid
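A quick shape check for make_grid (illustrative sizes): 3 images of 64x64 with nrow=2 form a 2x2 grid with one 2-pixel gutter per direction, hence 130x130:

imgs = mx.random.uniform(-1, 1, (3, 3, 64, 64))  # 3 RGB images in [-1, 1]
grid = make_grid(imgs, nrow=2)
print(grid.shape)  # (3, 130, 130): 2x2 cells of 64 px plus one 2 px gutter each way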

def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # preprocess
            tensor = mx.clip(tensor, value_range[0], value_range[1])

            # tensor shape: (batch, channels, frames, height, width)
            # Process each frame
            frames = []
            for frame_idx in range(tensor.shape[2]):
                frame = tensor[:, :, frame_idx, :, :]  # (batch, channels, height, width)
                grid = make_grid(frame, nrow=nrow, normalize=normalize, value_range=value_range)
                frames.append(grid)

            # Stack frames and convert to (frames, height, width, channels)
            tensor = mx.stack(frames, axis=0)  # (frames, channels, height, width)
            tensor = mx.transpose(tensor, [0, 2, 3, 1])  # (frames, height, width, channels)

            # Convert to uint8
            tensor = (tensor * 255).astype(mx.uint8)
            tensor_np = np.array(tensor)

            # write video
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            for frame in tensor_np:
                writer.append_data(frame)
            writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue
    else:
        print(f'cache_video failed, error: {error}', flush=True)
        return None
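A minimal round-trip sketch for cache_video (random data; writing .mp4 assumes imageio's ffmpeg backend is installed). The output lands in /tmp under a random name unless save_file is given:

video = mx.random.uniform(-1, 1, (1, 3, 16, 64, 64))  # (B, C, T, H, W) in [-1, 1]
path = cache_video(video, fps=8, nrow=1)
print(path)  # e.g. /tmp/<random>.mp4, or None if every retry failed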

def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
    """MLX equivalent of torchvision.utils.save_image"""
    # Make grid
    grid = make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range)

    # Convert to (height, width, channels) and uint8
    grid = mx.transpose(grid, [1, 2, 0])  # (height, width, channels)
    grid = (grid * 255).astype(mx.uint8)

    # Save using imageio
    imageio.imwrite(save_file, np.array(grid))

def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    # cache file
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in [
            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
    ]:
        suffix = '.png'

    # save to cache
    error = None
    for _ in range(retry):
        try:
            tensor = mx.clip(tensor, value_range[0], value_range[1])
            save_image(
                tensor,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
            continue
    else:
        # Mirror cache_video's failure path instead of silently returning None.
        print(f'cache_image failed, error: {error}', flush=True)
        return None

def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
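str2bool is intended as an argparse type so flags accept the usual spellings; a minimal sketch (the --offload flag is illustrative, not a real CLI option of this repo):

parser = argparse.ArgumentParser()
parser.add_argument("--offload", type=str2bool, default=False)  # illustrative flag
print(parser.parse_args(["--offload", "yes"]).offload)  # True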
175
video/Wan2.1/wan/vae_model_io.py
Normal file
@ -0,0 +1,175 @@
import re

import torch
import mlx.core as mx
import numpy as np
from typing import Dict, Tuple
from safetensors import safe_open


def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = False):
    """
    Convert PyTorch VAE weights to MLX format with correct mapping.
    """
    print(f"Converting {pytorch_path} -> {output_path}")
    dtype = mx.float16 if float16 else mx.float32

    # Load PyTorch weights
    if pytorch_path.endswith('.safetensors'):
        weights = {}
        with safe_open(pytorch_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if tensor.dtype == torch.bfloat16:
                    tensor = tensor.float()
                weights[key] = tensor.numpy()
    else:
        checkpoint = torch.load(pytorch_path, map_location='cpu')
        weights = {}
        state_dict = checkpoint if isinstance(checkpoint, dict) and 'state_dict' not in checkpoint else checkpoint.get('state_dict', checkpoint)
        for key, tensor in state_dict.items():
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()
            weights[key] = tensor.numpy()

    # Convert weights
    mlx_weights = {}

    for key, value in weights.items():
        # Skip buffers that have no MLX equivalent
        if any(skip in key for skip in ["num_batches_tracked", "running_mean", "running_var"]):
            continue

        # Convert weight formats
        if value.ndim == 5 and "weight" in key:  # Conv3d weights
            # PyTorch: (out_channels, in_channels, D, H, W)
            # MLX: (out_channels, D, H, W, in_channels)
            value = np.transpose(value, (0, 2, 3, 4, 1))
        elif value.ndim == 4 and "weight" in key:  # Conv2d weights
            # PyTorch: (out_channels, in_channels, H, W)
            # MLX Conv2d expects: (out_channels, H, W, in_channels)
            value = np.transpose(value, (0, 2, 3, 1))
        elif value.ndim == 1 and "bias" in key:  # Conv biases
            # Keep as is - MLX uses the same format
            pass

        # Map the key
        new_key = key

        # Map residual block internals within Sequential.
        # PyTorch: encoder.downsamples.0.residual.0.gamma
        # MLX:     encoder.downsamples.layers.0.residual.layers.0.gamma

        # Add .layers to Sequential modules
        new_key = re.sub(r'\.downsamples\.(\d+)', r'.downsamples.layers.\1', new_key)
        new_key = re.sub(r'\.upsamples\.(\d+)', r'.upsamples.layers.\1', new_key)
        new_key = re.sub(r'\.middle\.(\d+)', r'.middle.layers.\1', new_key)
        new_key = re.sub(r'\.head\.(\d+)', r'.head.layers.\1', new_key)

        # Map residual Sequential internals
        if ".residual." in new_key:
            match = re.search(r'\.residual\.(\d+)\.', new_key)
            if match:
                idx = int(match.group(1))
                if idx == 0:  # First RMS_norm
                    new_key = re.sub(r'\.residual\.0\.', '.residual.layers.0.', new_key)
                elif idx == 1:  # First SiLU - no parameters, skip
                    continue
                elif idx == 2:  # First Conv3d
                    new_key = re.sub(r'\.residual\.2\.', '.residual.layers.2.', new_key)
                elif idx == 3:  # Second RMS_norm
                    new_key = re.sub(r'\.residual\.3\.', '.residual.layers.3.', new_key)
                elif idx == 4:  # Second SiLU - no parameters, skip
                    continue
                elif idx == 5:  # Dropout - could be Identity in MLX
                    if "Dropout" in key:
                        continue
                    new_key = re.sub(r'\.residual\.5\.', '.residual.layers.5.', new_key)
                elif idx == 6:  # Second Conv3d
                    new_key = re.sub(r'\.residual\.6\.', '.residual.layers.6.', new_key)

        # Map resample internals. In both Encoder and Decoder Resample blocks,
        # the Conv2d sits at index 1 of the nn.Sequential, after a padding or
        # upsample layer, so PyTorch's .1 maps to MLX's .layers.1.
        if ".resample." in new_key:
            if ".resample.1." in new_key:
                new_key = new_key.replace(".resample.1.", ".resample.layers.1.")

            # The layers at index 0 (ZeroPad2d, Upsample) have no weights, so any
            # keys associated with them can be skipped safely.
            if ".resample.0." in key:
                continue

        # Head internals already use Sequential in MLX; the .layers index was
        # handled by the re.sub above.

        # Handle shortcut layers
        if ".shortcut." in new_key and "Identity" not in key:
            # Shortcut Conv3d layers - keep as is
            pass
        elif "Identity" in key:
            # Skip Identity modules
            continue

        # time_conv (Resample) and attention layers (to_qkv, proj) are already
        # correctly named and pass through unchanged.

        if "gamma" in new_key:
            # Squeeze gamma from (C, 1, 1) or (C, 1, 1, 1) down to (C,);
            # squeezing removes all dimensions of size 1.
            value = np.squeeze(value)

        # Add to MLX weights
        mlx_weights[new_key] = mx.array(value).astype(dtype)

    # Verify critical layers are present
    critical_prefixes = [
        "encoder.conv1", "decoder.conv1", "conv1", "conv2",
        "encoder.head.layers.2", "decoder.head.layers.2"  # Updated for Sequential
    ]

    for prefix in critical_prefixes:
        found = any(k.startswith(prefix) for k in mlx_weights.keys())
        if not found:
            print(f"WARNING: No weights found for {prefix}")

    print(f"Converted {len(mlx_weights)} parameters")

    # Print a few example keys for verification
    print("\nExample converted keys:")
    for key in sorted(mlx_weights.keys())[:10]:
        print(f"  {key}")

    # Save
    if output_path.endswith('.safetensors'):
        mx.save_safetensors(output_path, mlx_weights)
    else:
        mx.savez(output_path, **mlx_weights)

    print(f"\nSaved to {output_path}")
    print("\nAll converted keys:")
    for key in sorted(mlx_weights.keys()):
        print(f"  {key}: {mlx_weights[key].shape}")
    return mlx_weights

if __name__ == "__main__":
    import sys

    if len(sys.argv) < 3:
        print("Usage: python vae_model_io.py <input.pth> <output.safetensors> [--fp16]")
    else:
        convert_pytorch_to_mlx(
            sys.argv[1],
            sys.argv[2],
            "--fp16" in sys.argv
        )
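For example, converting the VAE checkpoint that ships with the official release to fp16 (paths assume the stock checkpoint layout; adjust to your setup):

python wan/vae_model_io.py Wan2.1-T2V-1.3B/Wan2.1_VAE.pth Wan2.1-T2V-1.3B/Wan2.1_VAE_mlx.safetensors --fp16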
228
video/Wan2.1/wan/wan_model_io.py
Normal file
@ -0,0 +1,228 @@
from typing import List, Tuple
import glob
import os

import mlx.core as mx
from mlx.utils import tree_unflatten
from safetensors import safe_open
import torch
import numpy as np


def map_wan_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
    # Remove .layers. from PyTorch Sequential keys to match MLX Python lists
    key = key.replace(".layers.", ".")

    # Conv3d patch embedding: move in_channels to the last axis for MLX
    if "patch_embedding.weight" in key:
        value = mx.transpose(value, (0, 2, 3, 4, 1))

    return [(key, value)]
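The patch_embedding transpose mirrors the Conv3d handling in the VAE converter: PyTorch stores (out_channels, in_channels, D, H, W) while MLX expects in_channels last. A quick shape check (sizes are illustrative, not read from a checkpoint):

w = mx.zeros((1536, 16, 1, 2, 2))  # hypothetical PyTorch-layout patch embedding
[(k, v)] = map_wan_weights("patch_embedding.weight", w)
print(v.shape)  # (1536, 1, 2, 2, 16): in_channels moved to the last axis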

def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
    """Flatten a nested list of parameter tuples"""
    return [(k, v) for p in params for (k, v) in p]

def load_wan_from_safetensors(
    safetensors_path: str,
    model,
    float16: bool = False
):
    """
    Load WanModel weights from safetensors file(s) into an MLX model.
    """
    if os.path.isdir(safetensors_path):
        # Multiple files (14B model) - only diffusion_mlx_model files
        pattern = os.path.join(safetensors_path, "diffusion_mlx_model*.safetensors")
        safetensor_files = sorted(glob.glob(pattern))
        print(f"Found {len(safetensor_files)} diffusion_mlx_model safetensors files")

        # Load all files and merge weights
        all_weights = {}
        for file_path in safetensor_files:
            print(f"Loading: {file_path}")
            weights = mx.load(file_path)
            all_weights.update(weights)

        model.update(tree_unflatten(list(all_weights.items())))
    else:
        # Single file (1.3B model)
        print(f"Loading single file: {safetensors_path}")
        weights = mx.load(safetensors_path)
        model.update(tree_unflatten(list(weights.items())))

    print("WanModel weights loaded successfully!")
    return model
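A minimal loading sketch, assuming a WanModel MLX module has already been constructed elsewhere in the repo (the class name and checkpoint path here are assumptions):

# model = WanModel(...)  # hypothetical: the MLX WanModel defined elsewhere in this repo
# model = load_wan_from_safetensors("Wan2.1-T2V-1.3B/diffusion_mlx_model.safetensors", model)
# mx.eval(model.parameters())  # force-materialize the lazily loaded weights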

def convert_safetensors_to_mlx_weights(
    safetensors_path: str,
    output_path: str,
    float16: bool = False
):
    """
    Convert a safetensors file to an MLX weights file.

    Args:
        safetensors_path: Input safetensors file
        output_path: Output MLX weights file (.safetensors)
        float16: Whether to use float16 precision
    """
    dtype = mx.float16 if float16 else mx.float32

    print("Converting safetensors to MLX format...")
    print(f"Input: {safetensors_path}")
    print(f"Output: {output_path}")
    print(f"Target dtype: {dtype}")

    # Load and convert weights
    weights = {}
    bfloat16_count = 0
    original_keys = []

    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        original_keys = list(f.keys())  # Store keys before closing
        print(f"Processing {len(original_keys)} parameters...")

        for key in original_keys:
            tensor = f.get_tensor(key)

            # Handle BFloat16
            if tensor.dtype == torch.bfloat16:
                bfloat16_count += 1
                tensor = tensor.float()  # Convert to float32 first

            value = mx.array(tensor.numpy()).astype(dtype)

            # Apply mapping
            mapped = map_wan_weights(key, value)

            for new_key, new_value in mapped:
                weights[new_key] = new_value

    if bfloat16_count > 0:
        print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to {dtype}")

    # Print mapping summary
    skipped = len(original_keys) - len(weights)
    if skipped > 0:
        print(f"ℹ️ Skipped {skipped} activation layer parameters")

    # Save as MLX format
    print(f"Saving {len(weights)} parameters to: {output_path}")
    mx.save_safetensors(output_path, weights)

    # Print a few example keys for verification
    print("\nExample converted keys:")
    for key in sorted(weights.keys())[:10]:
        print(f"  {key}: {weights[key].shape}")

    return weights

def convert_multiple_safetensors_to_mlx(
    checkpoint_dir: str,
    float16: bool = False
):
    """Convert multiple PyTorch safetensors files to MLX format."""
    # Find all PyTorch model files
    pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
    pytorch_files = sorted(glob.glob(pytorch_pattern))

    if not pytorch_files:
        raise FileNotFoundError(f"No PyTorch model files found matching: {pytorch_pattern}")

    print(f"Converting {len(pytorch_files)} PyTorch files to MLX format...")

    for i, pytorch_file in enumerate(pytorch_files, 1):
        # Extract the suffix (e.g., "00001-of-00006")
        basename = os.path.basename(pytorch_file)
        suffix = basename.replace("diffusion_pytorch_model-", "").replace(".safetensors", "")

        # Create MLX filename
        mlx_file = os.path.join(checkpoint_dir, f"diffusion_mlx_model-{suffix}.safetensors")

        print(f"Converting {i}/{len(pytorch_files)}: {basename}")
        convert_safetensors_to_mlx_weights(pytorch_file, mlx_file, float16)

    print("All files converted successfully!")

def debug_weight_mapping(safetensors_path: str, float16: bool = False):
    """
    Debug helper to see how weights are being mapped.
    """
    dtype = mx.float16 if float16 else mx.float32

    print("=== WAN Weight Mapping Debug ===")

    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        # Check the first 20 keys to see the mapping
        for i, key in enumerate(f.keys()):
            if i >= 20:
                break

            tensor = f.get_tensor(key)

            # Handle BFloat16
            original_dtype = tensor.dtype
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()

            value = mx.array(tensor.numpy()).astype(dtype)

            # Apply mapping
            mapped = map_wan_weights(key, value)

            if len(mapped) == 0:
                print(f"SKIPPED: {key} ({original_dtype})")
            elif len(mapped) == 1:
                new_key, new_value = mapped[0]
                if new_key == key:
                    print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]")
                else:
                    print(f"MAPPED: {key} -> {new_key} [{tensor.shape}]")

def check_model_structure(model):
    """
    Print the structure of an MLX model to debug loading issues.
    """
    from mlx.utils import tree_flatten

    print("=== Model Structure ===")
    params = dict(tree_flatten(model))
    print(f"Model has {len(params)} parameters")

    print("\nFirst 20 parameter names:")
    for i, (key, value) in enumerate(params.items()):
        if i >= 20:
            break
        print(f"  {key}: {value.shape}")

    return params

# Example usage
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 3:
        print("Usage: python wan_model_io.py <input.safetensors> <output.safetensors> [--fp16]")
        sys.exit(1)

    input_file = sys.argv[1]
    output_file = sys.argv[2]
    use_fp16 = "--fp16" in sys.argv

    # Debug the mapping first (optional)
    debug_weight_mapping(input_file, use_fp16)

    # Convert weights
    convert_safetensors_to_mlx_weights(input_file, output_file, float16=use_fp16)

    print("Conversion complete!")