Merge branch 'ml-explore:main' into adding-orpo-training

commit de147187c1
Author: Gökdeniz Gülmez
Date: 2025-02-21 19:59:43 +01:00 (committed by GitHub)
4 changed files with 15 additions and 21 deletions


@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.21.0"
+__version__ = "0.21.5"


@@ -208,8 +208,14 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
+
     if args.fine_tune_type == "full":
-        for l in model.layers[-min(args.num_layers, 0) :]:
+        for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
     elif args.fine_tune_type in ["lora", "dora"]:
         # Convert linear layers to lora/dora layers and unfreeze in the process
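
Note: the min -> max change is a real bug fix, not a style tweak. With a positive num_layers the old slice selected every layer rather than the last N. A toy sketch in plain Python (the four-element list is invented for illustration):

    layers = ["l0", "l1", "l2", "l3"]
    num_layers = 2
    # Old: -min(2, 0) == 0, and layers[0:] is the whole list, so "full"
    # fine-tuning silently unfroze every layer.
    print(layers[-min(num_layers, 0):])  # ['l0', 'l1', 'l2', 'l3']
    # New: -max(2, 0) == -2 selects only the last two layers as intended.
    print(layers[-max(num_layers, 0):])  # ['l2', 'l3']
    # num_layers == -1 (the "train all layers" convention) still works:
    # max(-1, 0) == 0, so layers[0:] keeps everything.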


@@ -52,11 +52,6 @@ def linear_to_lora_layers(
         use_dora (bool): If True, uses DoRA instead of LoRA.
           Default: ``False``
     """
-    if num_layers > len(model.layers):
-        raise ValueError(
-            f"Requested {num_layers} LoRA layers "
-            f"but the model only has {len(model.layers)} layers."
-        )
 
     def to_lora(layer):
         if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
@@ -154,7 +149,7 @@ def linear_to_lora_layers(
         else:
             raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[-min(num_layers, 0) :]:
+    for l in model.layers[-max(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))
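
Note: the layer-count validation moves up into train_model, and the same min -> max slice fix applies here. For readers unfamiliar with the swap mechanics in the last three lines, a minimal runnable sketch of the named_modules / tree_unflatten / update_modules pattern (the Block class is invented, and QuantizedLinear stands in for mlx-lm's LoRA module):

    import mlx.nn as nn
    from mlx.utils import tree_unflatten

    class Block(nn.Module):  # stand-in for one transformer layer
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(64, 64)

    block = Block()
    keys = {"proj"}  # real keys look like "self_attn.q_proj"

    # Collect (name, replacement) pairs for the selected linear modules,
    # then write them back onto the layer in one shot.
    swaps = [
        (k, nn.QuantizedLinear.from_linear(m))
        for k, m in block.named_modules()
        if k in keys
    ]
    block.update_modules(tree_unflatten(swaps))
    print(type(block.proj).__name__)  # QuantizedLinear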


@@ -409,8 +409,7 @@ def speculative_generate_step(
             for processor in logits_processors:
                 logits = processor(tokens, logits)
 
-        logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        logprobs = logprobs.squeeze(0)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
         y = sampler(logprobs)
         return y, logprobs
 
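
Note: without axis, mx.logsumexp reduces over every axis, so each row of logits is normalized against the whole batch instead of its own vocabulary distribution. A small sketch with invented numbers:

    import mlx.core as mx

    logits = mx.array([[0.0, 1.0, 2.0],
                       [5.0, 6.0, 7.0]])  # (positions, vocab)

    # Reduces over all axes: rows no longer sum to 1 after exp.
    bad = logits - mx.logsumexp(logits, keepdims=True)
    print(mx.exp(bad).sum(axis=-1))   # ~[0.0067, 0.9933]

    # Per-row normalization: each row is a proper distribution.
    good = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    print(mx.exp(good).sum(axis=-1))  # [1.0, 1.0]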
@@ -429,16 +428,14 @@ def speculative_generate_step(
                 prev_tokens = (
                     mx.concat([prev_tokens, y]) if prev_tokens is not None else y
                 )
-                y, logprobs = _process_and_sample(
-                    prev_tokens, logits[:, i : i + 1, :]
-                )
+                y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
                 out_y.append(y)
                 out_logprobs.append(logprobs)
             return mx.concatenate(out_y, axis=0), mx.concatenate(
                 out_logprobs, axis=0
             )
         else:
-            return _process_and_sample(None, logits)
+            return _process_and_sample(None, logits.squeeze(0))
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
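
Note: logits here has shape (batch, draft steps, vocab). The old i : i + 1 slice kept a singleton step axis that _process_and_sample then had to squeeze away; indexing with i drops it up front, and the squeeze moves to the single-logits else branch. A shape sketch with invented sizes:

    import mlx.core as mx

    logits = mx.zeros((1, 3, 5))  # (batch, draft steps, vocab)
    i = 1
    print(logits[:, i : i + 1, :].shape)  # (1, 1, 5): extra singleton axis
    print(logits[:, i, :].shape)          # (1, 5): what the sampler expects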
@@ -476,13 +473,9 @@ def speculative_generate_step(
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
             if prev_tokens is not None:
-                prev_tokens = prev_tokens[
-                    : prev_tokens.size - draft_y.size - num_draft + 1
-                ]
-
+                prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
             y = mx.concatenate([y, draft_tokens])
-
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
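
Note: prev_tokens is the running history the logits processors have already seen; before each verification round it is rolled back so it ends just before the tokens _step() is about to process, and the old code subtracted the draft input length (draft_y.size) instead of the carried-over y. The trim arithmetic with toy numbers (all values invented):

    prev_size = 20   # history length the processors have seen
    y_size = 2       # tokens carried over from the last round
    num_draft = 4    # fresh draft tokens this round
    keep = prev_size - y_size - num_draft + 1
    print(keep)      # 15 history positions survive the trim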
@@ -514,8 +507,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
 
-            if prev_tokens is not None and n < num_draft:
-                prev_tokens = prev_tokens[: -(num_draft - n)]
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)
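
Note: the rewritten guard sidesteps Python's -0 slice trap. The old code skipped the trim entirely when every draft token was accepted (n == num_draft), because -(num_draft - n) would be -0 and seq[:-0] means seq[:0], wiping the history; max(num_draft - n, 1) always drops at least the one bonus token instead. A two-line demonstration:

    history = [1, 2, 3, 4, 5]
    num_draft = n = 3                         # all drafts accepted
    print(history[: -(num_draft - n)])        # [] -- the -0 trap
    print(history[: -max(num_draft - n, 1)])  # [1, 2, 3, 4]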