remove simplify (#379)

This commit is contained in:
Awni Hannun 2024-01-26 13:54:49 -08:00 committed by GitHub
parent 0b57f0eae6
commit 5aa652d3c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 6 additions and 17 deletions

View File

@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
# TODO: simplify to v.nbytes as soon as mx.array exposes it
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
if shard_size + estimated_size > max_file_size_bytes:
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shard_size += v.nbytes
shards.append(shard)
return shards

View File

@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shard_size += v.nbytes
shards.append(shard)
return shards

View File

@ -102,12 +102,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shard_size += v.nbytes
shards.append(shard)
return shards

View File

@ -50,8 +50,6 @@ if __name__ == "__main__":
negative_text=args.negative_prompt,
)
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
mx.simplify(x_t)
mx.simplify(x_t)
mx.eval(x_t)
# Decode them into images

View File

@ -34,8 +34,6 @@ if __name__ == "__main__":
negative_text=args.negative_prompt,
)
for x_t in tqdm(latents, total=args.steps):
mx.simplify(x_t)
mx.simplify(x_t)
mx.eval(x_t)
# Decode them into images

View File

@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
mx.simplify(losses)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
@ -96,7 +94,6 @@ def main(args):
inputs, targets = map(mx.array, (inputs, targets))
loss, grads = loss_and_grad_fn(inputs, targets)
model.update(optimizer.apply_gradients(grads, model))
mx.simplify(loss, model.parameters())
mx.eval(loss, model.parameters())
losses.append(loss.item())
if (it + 1) % steps_per_report == 0: