Merge branch 'ml-explore:main' into adding-orpo-training

This commit is contained in:
Gökdeniz Gülmez
2025-02-04 11:04:40 +01:00
committed by GitHub
7 changed files with 52 additions and 32 deletions

View File

@@ -44,7 +44,8 @@ def shard_and_load(repo):
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
)
# Lazy load and shard model
# Lazy load and shard model to figure out
# which weights we need
model, _ = load_model(model_path, lazy=True, strict=False)
group = mx.distributed.init(backend="mpi")
@@ -62,8 +63,11 @@ def shard_and_load(repo):
# Download weights for local shard
download(args.model, allow_patterns=local_files)
# Load and shard the model, and load the weights
tokenizer = load_tokenizer(model_path)
model, _ = load_model(model_path)
model, _ = load_model(model_path, lazy=True, strict=False)
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.