remove simplify (#379)

This commit is contained in:
Awni Hannun
2024-01-26 13:54:49 -08:00
committed by GitHub
parent 0b57f0eae6
commit 5aa652d3c2
6 changed files with 6 additions and 17 deletions

View File

@@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
# TODO: simplify to v.nbytes as soon as mx.array exposes it
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
if shard_size + estimated_size > max_file_size_bytes:
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shard_size += v.nbytes
shards.append(shard)
return shards

View File

@@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shard_size += v.nbytes
shards.append(shard)
return shards