mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
ordering
This commit is contained in:
parent
4385363c0f
commit
d5f49d65b9
@ -222,10 +222,10 @@ class MLXLM(LM):
|
||||
if group is not None:
|
||||
per_group = int(np.ceil(num_results / group.size()))
|
||||
scores = mx.pad(scores, ((0, per_group - len(scores)),))
|
||||
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
|
||||
scores = scores.T.reshape(-1)
|
||||
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
|
||||
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
|
||||
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
|
||||
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
|
||||
scores = scores.T.reshape(-1)
|
||||
is_greedy = is_greedy.T.reshape(-1)
|
||||
|
||||
scores = np.array(scores[:num_results])
|
||||
|
Loading…
Reference in New Issue
Block a user