diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py index 95f0fcb8..7f6f34db 100644 --- a/llms/mlx_lm/examples/pipeline_generate.py +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -51,7 +51,7 @@ mx.eval(model.parameters()) # Synchronize processes before generation to avoid timeout if downloading # model for the first time. -mx.eval(mx.distributed.all_sum(mx.array(1.0))) +mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) def rprint(*args, **kwargs):