mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
Fix variable reference
This commit is contained in:
parent
27cd361d76
commit
30fd5af843
@ -235,6 +235,10 @@ def train_model(
|
|||||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.mask_inputs:
|
||||||
|
print("Masking inputs..")
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
train(
|
train(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -168,7 +168,7 @@ def iterate_delineated_batches(
|
|||||||
for j in batch_idx[i]:
|
for j in batch_idx[i]:
|
||||||
prompt, completion = dataset.get_prompt_and_completion(j)
|
prompt, completion = dataset.get_prompt_and_completion(j)
|
||||||
prompt_length, completion_length = input_and_output_lengths(
|
prompt_length, completion_length = input_and_output_lengths(
|
||||||
prompt, prompt, tokenizer
|
prompt, completion, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_lengths.append(prompt_length)
|
prompt_lengths.append(prompt_length)
|
||||||
|
Loading…
Reference in New Issue
Block a user