# Mirror of https://github.com/ml-explore/mlx-examples.git
# Synced 2025-08-21 20:46:50 +08:00 -- 233 lines, 7.0 KiB, Python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
# utils MLX version
|
|
import argparse
|
|
import binascii
|
|
import logging
|
|
import os
|
|
import os.path as osp
|
|
|
|
import imageio
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
|
|
# Names exported by a star-import of this module. `rand_name` and `make_grid`
# are defined in this file but are not listed here, so a star-import does not
# expose them.
__all__ = ['save_video', 'save_image', 'str2bool', 'masks_like', 'best_output_size']
|
|
|
|
|
|
def rand_name(length=8, suffix=''):
    """Return a random lowercase-hex file name, optionally with an extension.

    Args:
        length: Number of random bytes to draw; the hex stem is twice as long.
        suffix: Optional extension. A leading dot is added when it is missing.

    Returns:
        str: ``2 * length`` hex characters followed by the normalised suffix.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
|
|
|
|
|
|
def make_grid(tensor, nrow=8, normalize=True, value_range=(-1, 1)):
    """Tile a batch of images into one image (MLX take on torchvision's make_grid).

    Images are laid out left-to-right, top-to-bottom with a 2-pixel zero-valued
    gap between neighbouring tiles (no outer border).

    Args:
        tensor: MLX array indexed as (batch, channels, height, width).
        nrow: Maximum number of images per grid row. When the batch is smaller
            than ``nrow`` the grid is only as wide as the batch, so a single
            image is returned at its own width rather than padded out to
            ``nrow`` columns.
        normalize: If True, linearly map ``value_range`` onto [0, 1].
        value_range: ``(low, high)`` pair used for normalisation.

    Returns:
        MLX array of shape (channels, grid_height, grid_width).

    Raises:
        ValueError: If the batch is empty, ``nrow`` is smaller than 1, or
            ``normalize`` is requested with ``low == high``.
    """
    padding = 2  # pixels of empty space between adjacent tiles

    batch_size, channels, height, width = tensor.shape
    if batch_size == 0:
        raise ValueError('make_grid requires at least one image')
    if nrow < 1:
        raise ValueError(f'nrow must be >= 1, got {nrow}')

    # Never allocate more columns than there are images; otherwise a small
    # batch yields a canvas that is mostly empty padding.
    ncol = min(nrow, batch_size)
    nrow_actual = (batch_size + ncol - 1) // ncol  # ceil division

    grid_height = height * nrow_actual + (nrow_actual - 1) * padding
    grid_width = width * ncol + (ncol - 1) * padding

    if normalize:
        low, high = value_range[0], value_range[1]
        if high == low:
            raise ValueError('value_range must span a non-zero interval')
        # Normalise the whole batch once instead of once per tile.
        tensor = (tensor - low) / (high - low)

    grid = mx.zeros((channels, grid_height, grid_width))
    for idx in range(batch_size):
        row, col = divmod(idx, ncol)
        top = row * (height + padding)
        left = col * (width + padding)
        grid[:, top:top + height, left:left + width] = tensor[idx]

    return grid
|
|
|
|
|
|
def save_video(tensor,
               save_file=None,
               fps=30,
               suffix='.mp4',
               nrow=8,
               normalize=True,
               value_range=(-1, 1)):
    """Tile a batch of videos frame by frame and encode the result with imageio.

    Args:
        tensor: MLX array indexed as (batch, channels, frames, height, width).
        save_file: Destination path. When None, a randomly named file ending in
            ``suffix`` is created under ``/tmp``.
        fps: Frames per second passed to the video writer.
        suffix: Extension used only for the auto-generated file name.
        nrow: Forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: ``(low, high)`` clip range, also forwarded to ``make_grid``.

    Returns:
        The path that was written, or None if saving failed. Failures are
        logged rather than raised (best-effort behaviour).
    """
    if save_file is None:
        cache_file = osp.join('/tmp', rand_name(suffix=suffix))
    else:
        cache_file = save_file

    try:
        # Clip first so normalisation inside make_grid stays within [0, 1].
        tensor = mx.clip(tensor, value_range[0], value_range[1])

        # One grid per frame; each slice is (batch, channels, height, width).
        frames = [
            make_grid(
                tensor[:, :, frame_idx, :, :],
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            for frame_idx in range(tensor.shape[2])
        ]

        # (frames, channels, H, W) -> (frames, H, W, channels) for the encoder.
        video = mx.transpose(mx.stack(frames, axis=0), [0, 2, 3, 1])
        video_np = np.array((video * 255).astype(mx.uint8))

        writer = imageio.get_writer(
            cache_file, fps=fps, codec='libx264', quality=8)
        try:
            for frame in video_np:
                writer.append_data(frame)
        finally:
            # Always release the encoder, even if a frame fails to encode.
            writer.close()
    except Exception as e:
        # WARNING (not INFO) so the failure shows up under default logging.
        logging.warning(f'save_video failed, error: {e}')
        return None

    # Hand the path back so an auto-generated /tmp name can be located.
    return cache_file
|
|
|
|
|
|
def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
    """Tile a batch of images into a grid and write it to ``save_file``.

    Args:
        tensor: MLX array indexed as (batch, channels, height, width).
        save_file: Destination path. If its extension is not one of
            .jpg/.jpeg/.png/.tiff/.gif/.webp, the data is encoded as PNG while
            the file name is kept as given.
        nrow: Forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: ``(low, high)`` clip range, also forwarded to ``make_grid``.

    Returns:
        ``save_file`` on success, or None if saving failed. Failures are
        logged rather than raised (best-effort behaviour).
    """
    known_suffixes = ('.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp')
    suffix = osp.splitext(save_file)[1]

    # Only override the encoder when the extension is not in the list above;
    # listed extensions keep imageio's own extension-based selection.
    write_kwargs = {}
    if suffix.lower() not in known_suffixes:
        write_kwargs['format'] = '.png'

    try:
        tensor = mx.clip(tensor, value_range[0], value_range[1])
        grid = make_grid(
            tensor, nrow=nrow, normalize=normalize, value_range=value_range)

        # (channels, H, W) -> (H, W, channels), then scale to 8-bit.
        grid = mx.transpose(grid, [1, 2, 0])
        grid = (grid * 255).astype(mx.uint8)

        imageio.imwrite(save_file, np.array(grid), **write_kwargs)
        return save_file
    except Exception as e:
        # WARNING (not INFO) so the failure shows up under default logging.
        logging.warning(f'save_image failed, error: {e}')
        return None
|
|
|
|
|
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'
    Matching is case-insensitive. A ``bool`` is returned unchanged.

    Args:
        v (str | bool): Value to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to
            boolean, including values that are neither ``str`` nor ``bool``.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Without this guard `.lower()` would raise AttributeError, which is
        # not the documented failure mode.
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')

    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
|
|
|
|
|
def masks_like(tensor, zero=False, generator=None, p=0.2):
    """
    Build two lists of all-ones masks shaped like the given arrays.

    When ``zero`` is set, index 0 along the second axis of each mask is
    rewritten:

    * ``generator is None``: that slice becomes zeros in both masks.
    * otherwise: a uniform draw is made per entry; if it is below ``p`` the
      slice of the first mask becomes ``exp(N(-3.5, 0.5))`` samples and the
      slice of the second mask becomes zeros; if not, both stay ones.

    Args:
        tensor: List of MLX arrays.
        zero: Whether to rewrite the first slice as described above.
        generator: Only checked against None; random numbers always come from
            MLX's global random state.
        p: Probability threshold for the random branch.

    Returns:
        Tuple of two lists of masks.
    """
    assert isinstance(tensor, list)

    first_masks = [mx.ones(item.shape, dtype=item.dtype) for item in tensor]
    second_masks = [mx.ones(item.shape, dtype=item.dtype) for item in tensor]

    if not zero:
        return first_masks, second_masks

    for first, second in zip(first_masks, second_masks):
        if generator is None:
            first[:, 0] = mx.zeros_like(first[:, 0])
            second[:, 0] = mx.zeros_like(second[:, 0])
            continue

        draw = mx.random.uniform(0, 1, shape=(1,)).item()
        if draw < p:
            noise = mx.random.normal(
                shape=first[:, 0].shape, loc=-3.5, scale=0.5)
            first[:, 0] = mx.exp(noise)
            second[:, 0] = mx.zeros_like(second[:, 0])
        # Otherwise both masks keep their all-ones values for this entry.

    return first_masks, second_masks
|
|
|
|
|
|
def best_output_size(w, h, dw, dh, expected_area):
    """
    Calculate the best output size given constraints.

    Two candidates are built around the ideal (unrounded) size that has the
    input aspect ratio and ``expected_area`` pixels: one rounds the width down
    to a multiple of ``dw`` first, the other rounds the height down to a
    multiple of ``dh`` first. The candidate whose aspect ratio is closest to
    ``w / h`` wins; on a tie the height-first candidate is returned.

    Args:
        w: Width
        h: Height
        dw: Width divisor
        dh: Height divisor
        expected_area: Target area

    Returns:
        Tuple of (output_width, output_height)

    Raises:
        ValueError: If any argument is not positive, or if no non-empty size
            that honours both divisors fits inside ``expected_area``.
    """
    if min(w, h, dw, dh, expected_area) <= 0:
        raise ValueError(
            'w, h, dw, dh and expected_area must all be positive')

    # float output size
    ratio = w / h
    ow = (expected_area * ratio)**0.5
    oh = expected_area / ow

    def _distortion(size):
        # >= 1; equals 1 when the candidate has exactly the input ratio.
        cand_ratio = size[0] / size[1]
        return max(ratio / cand_ratio, cand_ratio / ratio)

    candidates = []

    # process width first
    ow1 = int(ow // dw * dw)
    if ow1 > 0:
        oh1 = int(expected_area / ow1 // dh * dh)
        if oh1 > 0:
            assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area
            candidates.append((ow1, oh1))

    # process height first
    oh2 = int(oh // dh * dh)
    if oh2 > 0:
        ow2 = int(expected_area / oh2 // dw * dw)
        if ow2 > 0:
            assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area
            candidates.append((ow2, oh2))

    if not candidates:
        raise ValueError(
            'expected_area is too small for the requested divisors')

    # compare ratios; a side rounded down to 0 was skipped above instead of
    # being divided by
    if (len(candidates) == 2
            and _distortion(candidates[0]) < _distortion(candidates[1])):
        return candidates[0]
    return candidates[-1]