mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
remove simplify (#379)
This commit is contained in:
@@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
# TODO: simplify to v.nbytes as soon as mx.array exposes it
|
||||
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
@@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
Reference in New Issue
Block a user