Fix variable reference

This commit is contained in:
Chime Ogbuji 2024-11-06 12:33:49 -05:00 committed by Awni Hannun
parent 27cd361d76
commit 30fd5af843
2 changed files with 5 additions and 1 deletions

View File

@ -235,6 +235,10 @@ def train_model(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
if args.mask_inputs:
print("Masking inputs..")
# Train model
train(
model=model,

View File

@ -168,7 +168,7 @@ def iterate_delineated_batches(
for j in batch_idx[i]:
prompt, completion = dataset.get_prompt_and_completion(j)
prompt_length, completion_length = input_and_output_lengths(
prompt, prompt, tokenizer
prompt, completion, tokenizer
)
prompt_lengths.append(prompt_length)