This commit is contained in:
Awni Hannun 2025-01-06 06:10:50 -08:00
parent 22d4a20dc2
commit 7fed460146

View File

@ -51,7 +51,7 @@ mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0)))
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
def rprint(*args, **kwargs):