diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index 8fa00dd2..56dce27c 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -222,10 +222,10 @@ class MLXLM(LM): if group is not None: per_group = int(np.ceil(num_results / group.size())) scores = mx.pad(scores, ((0, per_group - len(scores)),)) - scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu) - scores = scores.T.reshape(-1) - is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu) is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy)))) + scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu) + is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu) + scores = scores.T.reshape(-1) is_greedy = is_greedy.T.reshape(-1) scores = np.array(scores[:num_results])