Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 15:50:57 +08:00)
Add input_masked loss calculation and batching w/ padding
This commit is contained in:
parent f58c7de901
commit 604be3cec9
@@ -63,6 +63,65 @@ class TrainingArgs:
     )
 
 
+def iterate_input_masked_batches(
+    input_text, output_text, tokenizer, max_seq_length=2048
+):
+    batch_size = len(input_text)
+
+    input_batch = [tokenizer.encode(record) for record in input_text]
+    output_batch = [tokenizer.encode(record) for record in output_text]
+
+    input_lengths = [len(x) for x in input_batch]
+    output_lengths = [len(x) for x in output_batch]
+
+    full_labels = [input_batch[idx] + output_batch[idx] for idx in range(batch_size)]
+    lengths = [len(x) for x in full_labels]
+
+    if max(lengths) > max_seq_length:
+        print(
+            f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+            f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
+            "Consider pre-splitting your data to save memory."
+        )
+    pad_to = 8
+    max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+    max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
+    batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+    adjusted_lengths = []
+    for j in range(batch_size):
+        input_length = input_lengths[j]
+        full_ids_end_idx = input_length + min(
+            output_lengths[j], max_length_in_batch - input_length
+        )
+        adjusted_lengths.append(full_ids_end_idx)
+        batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]
+
+    batch = mx.array(batch_arr)
+    input_lengths = mx.array(input_lengths)
+    lengths = mx.array(adjusted_lengths)
+
+    return batch[:, :-1], input_lengths, lengths
+
+
+def input_masked_loss(model, inputs, input_lengths, lengths):
+    shifted_inputs = inputs[:, :-1]
+    shifted_labels = inputs[:, 1:]
+    logits = model(shifted_inputs)
+    logits = logits.astype(mx.float32)
+
+    mask_width = shifted_inputs.shape[1]
+    token_indices = mx.arange(mask_width)[None, :]
+    mask = mx.logical_and(
+        token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
+    )
+
+    ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
+    ntoks = mask.sum()
+    ce = ce.sum() / ntoks
+    return ce, ntoks
+
+
 def default_loss(model, inputs, targets, lengths):
     logits = model(inputs)
     logits = logits.astype(mx.float32)
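The batch width chosen above rounds the longest prompt-plus-completion sequence up to the next multiple of pad_to = 8 and then caps it at max_seq_length. A quick worked instance of that arithmetic (the lengths here are illustrative, not from the commit):

    pad_to, max_seq_length = 8, 2048
    lengths = [13, 27, 9]  # token counts of prompt + completion per example
    max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)  # ceil(27 / 8) * 8 = 32
    max_length_in_batch = min(max_length_in_batch, max_seq_length)          # 32, well under the cap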
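For orientation, here is a minimal sketch of how the two added functions could be exercised together, assuming they are in scope along with the numpy/mlx imports. The ToyTokenizer and ToyModel below are stand-ins invented for illustration, not part of the commit; the real trainer passes the model's own tokenizer and an MLX language model, and only the .encode() and __call__ interfaces used by the new code are mocked here.

    import numpy as np
    import mlx.core as mx
    import mlx.nn as nn


    class ToyTokenizer:
        # Stand-in tokenizer: only .encode() is required by iterate_input_masked_batches.
        def encode(self, text):
            return [ord(c) % 100 for c in text]


    class ToyModel(nn.Module):
        # Stand-in causal LM: maps token ids to per-token logits over a tiny vocabulary.
        def __init__(self, vocab_size=100, dims=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dims)
            self.out = nn.Linear(dims, vocab_size)

        def __call__(self, x):
            return self.out(self.embed(x))


    prompts = ["Q: 2 + 2 = ? A:", "Q: capital of France? A:"]
    answers = [" 4", " Paris"]

    # Build one padded batch; input_lengths marks where each completion starts,
    # and lengths holds the (possibly truncated) prompt + completion length per row.
    batch, input_lengths, lengths = iterate_input_masked_batches(
        prompts, answers, ToyTokenizer()
    )

    # Cross-entropy is averaged only over completion tokens; prompt tokens are masked out.
    loss, ntoks = input_masked_loss(ToyModel(), batch, input_lengths, lengths)
    print(loss.item(), ntoks.item())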