This commit is contained in:
Alex Barron 2024-12-19 00:08:28 -08:00
parent 4385363c0f
commit d5f49d65b9

View File

@ -222,10 +222,10 @@ class MLXLM(LM):
if group is not None:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)
scores = np.array(scores[:num_results])