mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-11-03 13:08:08 +08:00
feat(lora): add de-quantized support for fuse.py (#351)
* feat(lora): add de-quantized support for fuse.py
* address comments
This commit is contained in:
@@ -115,11 +115,21 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
def save_model(save_dir: str, weights, tokenizer, config) -> None:
    """Write model weights, tokenizer files and config into ``save_dir``.

    The weights are split with ``make_shards`` (capped via
    ``max_file_size_gibibyte=5``) and each shard is written as a
    safetensors file. With more than one shard the files are named
    ``model-00001-of-0000N.safetensors`` (1-based index, zero-padded to
    five digits); otherwise a single ``model.safetensors`` is produced.

    Args:
        save_dir: Destination directory; created (with parents) if missing.
        weights: Mapping of weight names to arrays, as accepted by
            ``make_shards``.
        tokenizer: Object exposing ``save_pretrained(path)``.
        config: JSON-serializable model configuration, written to
            ``config.json`` with 4-space indentation.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights, max_file_size_gibibyte=5)
    shards_count = len(shards)
    # Only use the indexed naming scheme when there is actually more than
    # one shard; a lone shard keeps the plain file name.
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    for i, shard in enumerate(shards):
        # Shard numbering in the file name is 1-based. For the single-file
        # format the extra arguments are simply ignored by str.format.
        shard_name = shard_file_format.format(i + 1, shards_count)
        mx.save_safetensors(str(save_dir / shard_name), shard)

    tokenizer.save_pretrained(save_dir)

    with open(save_dir / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user