mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 12:51:12 +08:00
fix
This commit is contained in:
parent
e80bf95182
commit
35ecc17042
@ -310,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
|||||||
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
|
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
|
||||||
else:
|
else:
|
||||||
print(f"Loading Hugging Face dataset {args.data}.")
|
print(f"Loading Hugging Face dataset {args.data}.")
|
||||||
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
|
train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
|
||||||
|
|
||||||
if args.train and len(train) == 0:
|
if args.train and len(train) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -117,7 +117,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
|
|||||||
|
|
||||||
end_sequence = tokenizer.encode("</answer>")
|
end_sequence = tokenizer.encode("</answer>")
|
||||||
end_sequence_length = len(end_sequence)
|
end_sequence_length = len(end_sequence)
|
||||||
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
|
output = mx.zeros((prompt.shape[1] + max_tokens,))
|
||||||
output[:prompt.shape[1]] = prompt[0]
|
output[:prompt.shape[1]] = prompt[0]
|
||||||
current_length = prompt.shape[1]
|
current_length = prompt.shape[1]
|
||||||
|
|
||||||
@ -126,7 +126,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
|
|||||||
if temperature > 0:
|
if temperature > 0:
|
||||||
logits /= temperature
|
logits /= temperature
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
|
return mx.random.categorical(logprobs[None, :])[0]
|
||||||
|
|
||||||
for _ in range(max_tokens):
|
for _ in range(max_tokens):
|
||||||
current_input = output[:current_length][None, :]
|
current_input = output[:current_length][None, :]
|
||||||
|
Loading…
Reference in New Issue
Block a user