diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index c3e0b191..ba22764c 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -502,7 +502,7 @@ def save_weights( shard_name = shard_file_format.format(i + 1, shards_count) shard_path = save_path / shard_name - mx.save_safetensors(str(shard_path), shard) + mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"}) for weight_name in shard.keys(): index_data["weight_map"][weight_name] = shard_name