mlx-examples (mirror of https://github.com/ml-explore/mlx-examples.git)

Update for nn.layers.distributed

commit 208856520d
parent 02b007f19c
@@ -109,10 +109,10 @@ class Flux(nn.Module):
             block.txt_attn.num_heads //= N
             block.sharding_group = group
             block.img_attn.qkv = shard_linear(
-                block.img_attn.qkv, "all-to-sharded", groups=3, group=group
+                block.img_attn.qkv, "all-to-sharded", segments=3, group=group
             )
             block.txt_attn.qkv = shard_linear(
-                block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
+                block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
             )
             shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
             shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
@@ -131,11 +131,11 @@ class Flux(nn.Module):
             block.linear1 = shard_linear(
                 block.linear1,
                 "all-to-sharded",
-                groups=[1 / 7, 2 / 7, 3 / 7],
+                segments=[1 / 7, 2 / 7, 3 / 7],
                 group=group,
             )
             block.linear2 = shard_linear(
-                block.linear2, "sharded-to-all", groups=[1 / 5], group=group
+                block.linear2, "sharded-to-all", segments=[1 / 5], group=group
             )

     def __call__(
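
The two hunks above track a keyword rename in the sharding helpers: the argument that says how a fused layer's output features are split is now segments instead of groups, which frees group to refer only to the communication group. Below is a minimal sketch of the renamed calls; the import path follows the nn.layers.distributed module named in the commit message, and the layer sizes are placeholders, not values from this repo.

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear, shard_inplace

group = mx.distributed.init()

# Fused QKV projection: segments=3 splits the output features into three
# equal parts (Q, K, V), each sharded across the ranks in `group`.
qkv = shard_linear(
    nn.Linear(1024, 3 * 1024), "all-to-sharded", segments=3, group=group
)

# segments can also be a list of fractional split points, as in the
# single-block hunk above where a fused 7x-wide projection is split.
linear1 = shard_linear(
    nn.Linear(1024, 7 * 1024),
    "all-to-sharded",
    segments=[1 / 7, 2 / 7, 3 / 7],
    group=group,
)

# Output projections consume sharded features and reduce back to all
# ranks; shard_inplace modifies the layer rather than returning a new one.
proj = nn.Linear(1024, 1024)
shard_inplace(proj, "sharded-to-all", group=group)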
@@ -62,6 +62,7 @@ if __name__ == "__main__":
     parser.add_argument("--adapter")
     parser.add_argument("--fuse-adapter", action="store_true")
     parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
+    parser.add_argument("--force-shard", action="store_true")
     args = parser.parse_args()

     # Load the models
@@ -77,8 +78,16 @@ if __name__ == "__main__":
         nn.quantize(flux.clip, class_predicate=quantization_predicate)

     group = mx.distributed.init()
+    n_images = args.n_images
+    should_gather = False
     if group.size() > 1:
-        flux.flow.shard(group)
+        if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
+            flux.flow.shard(group)
+            if args.seed is None:
+                args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
+        else:
+            n_images //= group.size()
+            should_gather = True

     if args.preload_models:
         flux.ensure_models_are_loaded()
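
The new branch chooses between two distribution strategies. If sharding is forced, or the requested image count cannot be split evenly across ranks, the flow transformer is tensor-sharded and every rank computes the same images; otherwise each rank generates n_images // group.size() images on its own and the results are gathered after decoding. In the sharded case all ranks must denoise identical latents, so they agree on a common seed by all-summing per-rank random draws: the sum comes out the same on every rank. A small sketch of that trick:

import mlx.core as mx

group = mx.distributed.init()

# Each rank draws its own random integer. all_sum returns the same total
# on every rank, so the result can serve as a shared seed without any
# extra coordination.
shared_seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
mx.random.seed(shared_seed)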
@@ -87,7 +96,7 @@ if __name__ == "__main__":
     latent_size = to_latent_size(args.image_size)
     latents = flux.generate_latents(
         args.prompt,
-        n_images=args.n_images,
+        n_images=n_images,
         num_steps=args.steps,
         latent_size=latent_size,
         guidance=args.guidance,
@@ -97,8 +106,8 @@ if __name__ == "__main__":
     # First we get and eval the conditioning
     conditioning = next(latents)
     mx.eval(conditioning)
-    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()

     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the text encoders.
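
This hunk, like the matching ones below, follows MLX's move of the peak-memory helpers from the mx.metal namespace to the top level. A quick usage sketch, assuming a recent MLX where the top-level calls exist:

import mlx.core as mx

# Peak-memory tracking is now a top-level, backend-agnostic API.
x = mx.ones((4096, 4096))
mx.eval(x)
print(f"peak memory: {mx.get_peak_memory() / 1024**3:.3f} GB")
mx.reset_peak_memory()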
@@ -106,36 +115,42 @@ if __name__ == "__main__":
     del flux.clip

     # Actual denoising loop
-    for x_t in tqdm(latents, total=args.steps):
+    for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
         mx.eval(x_t)

     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the flow transformer.
     del flux.flow
-    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_generation = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()

     # Decode them into images
     decoded = []
     for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
         decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
         mx.eval(decoded[-1])
-    peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
+    peak_mem_decoding = mx.get_peak_memory() / 1024**3
     peak_mem_overall = max(
         peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
     )

+    # Gather them if each node has different images
+    decoded = mx.concatenate(decoded, axis=0)
+    if should_gather:
+        decoded = mx.distributed.all_gather(decoded)
+    mx.eval(decoded)
+
     if args.save_raw:
         *name, suffix = args.output.split(".")
         name = ".".join(name)
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = (x * 255).astype(mx.uint8)
         for i in range(len(x)):
             im = Image.fromarray(np.array(x[i]))
             im.save(".".join([name, str(i), suffix]))
     else:
         # Arrange them on a grid
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
         B, H, W, C = x.shape
         x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
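
In the data-parallel path each rank decodes only its own slice of images, so the script now concatenates the local batches and, when should_gather is set, collects every rank's images with mx.distributed.all_gather, which concatenates along the first axis. (The denoising progress bar is also silenced on non-zero ranks via disable=group.rank() > 0.) A minimal sketch of the gather, with placeholder shapes:

import mlx.core as mx

group = mx.distributed.init()

# Placeholder for this rank's decoded images: [n_local, H, W, C].
local_images = mx.zeros((2, 512, 512, 3))

# all_gather concatenates along axis 0, so every rank ends up with the
# combined [n_local * group.size(), H, W, C] batch.
all_images = mx.distributed.all_gather(local_images)
mx.eval(all_images)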