mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 00:30:09 +08:00
Fix variable reference
This commit is contained in:
parent
27cd361d76
commit
30fd5af843
@ -235,6 +235,10 @@ def train_model(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
)
|
||||
|
||||
if args.mask_inputs:
|
||||
print("Masking inputs..")
|
||||
|
||||
# Train model
|
||||
train(
|
||||
model=model,
|
||||
|
@ -168,7 +168,7 @@ def iterate_delineated_batches(
|
||||
for j in batch_idx[i]:
|
||||
prompt, completion = dataset.get_prompt_and_completion(j)
|
||||
prompt_length, completion_length = input_and_output_lengths(
|
||||
prompt, prompt, tokenizer
|
||||
prompt, completion, tokenizer
|
||||
)
|
||||
|
||||
prompt_lengths.append(prompt_length)
|
||||
|
Loading…
Reference in New Issue
Block a user