remove simplify (#379)

Awni Hannun 2024-01-26 13:54:49 -08:00 committed by GitHub
parent 0b57f0eae6
commit 5aa652d3c2
6 changed files with 6 additions and 17 deletions


@@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
-        # TODO: simplify to v.nbytes as soon as mx.array exposes it
-        estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
-        if shard_size + estimated_size > max_file_size_bytes:
+        if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
             shard, shard_size = {}, 0
         shard[k] = v
-        shard_size += estimated_size
+        shard_size += v.nbytes
     shards.append(shard)
     return shards
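For context: mx.array now exposes nbytes, which is exactly what the removed TODO and hand-rolled estimate were standing in for. A quick sanity check, assuming an MLX build new enough to include the attribute:

import mlx.core as mx

v = mx.zeros((1024, 1024), dtype=mx.float16)
# nbytes == element count * bytes per element, i.e. the old manual estimate
assert v.nbytes == v.size * v.dtype.size == 2 * 1024 * 1024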


@@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
-        estimated_size = v.size * v.dtype.size
-        if shard_size + estimated_size > max_file_size_bytes:
+        if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
             shard, shard_size = {}, 0
         shard[k] = v
-        shard_size += estimated_size
+        shard_size += v.nbytes
     shards.append(shard)
     return shards


@@ -102,12 +102,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
-        estimated_size = v.size * v.dtype.size
-        if shard_size + estimated_size > max_file_size_bytes:
+        if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
             shard, shard_size = {}, 0
         shard[k] = v
-        shard_size += estimated_size
+        shard_size += v.nbytes
     shards.append(shard)
     return shards
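All three make_shards copies now share the same logic. A self-contained usage sketch of the post-change function; the bytes conversion line is assumed (it is not shown in the hunks), and the weight names and sizes are made up:

import mlx.core as mx

def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
    # assumed conversion, not shown in the hunks above
    max_file_size_bytes = max_file_size_gibibyte << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        # start a new shard once adding v would exceed the cap
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards

# eight 4 MiB float32 tensors easily fit in one 1 GiB shard
weights = {f"layers.{i}.weight": mx.zeros((1024, 1024)) for i in range(8)}
shards = make_shards(weights, max_file_size_gibibyte=1)
assert sum(len(s) for s in shards) == len(weights)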


@@ -50,8 +50,6 @@ if __name__ == "__main__":
         negative_text=args.negative_prompt,
     )
     for x_t in tqdm(latents, total=int(args.steps * args.strength)):
-        mx.simplify(x_t)
-        mx.simplify(x_t)
         mx.eval(x_t)

     # Decode them into images
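The remaining mx.eval is what forces each denoising step to run: MLX records operations lazily, so evaluating once per iteration keeps the graph bounded without any explicit simplify pass. A minimal sketch of the pattern, with a placeholder latent shape:

import mlx.core as mx

x_t = mx.zeros((1, 64, 64, 4))   # placeholder latent; the real shape comes from the model
for _ in range(4):               # stand-in for the sampler's steps
    x_t = x_t * 0.5 + 0.1        # ops are recorded lazily, nothing computed yet
    mx.eval(x_t)                 # materialize this step's result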


@@ -34,8 +34,6 @@ if __name__ == "__main__":
         negative_text=args.negative_prompt,
     )
     for x_t in tqdm(latents, total=args.steps):
-        mx.simplify(x_t)
-        mx.simplify(x_t)
         mx.eval(x_t)

     # Decode them into images


@@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
     def loss(self, x, y, reduce=True):
         logits = self(x)
         losses = nn.losses.cross_entropy(logits, y)
-        mx.simplify(losses)
-
         return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))

@@ -96,7 +94,6 @@ def main(args):
         inputs, targets = map(mx.array, (inputs, targets))
         loss, grads = loss_and_grad_fn(inputs, targets)
         model.update(optimizer.apply_gradients(grads, model))
-        mx.simplify(loss, model.parameters())
         mx.eval(loss, model.parameters())
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
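For reference, a self-contained sketch of the same train-step pattern after this change. The tiny module, optimizer choice, and dummy batch are hypothetical; nn.value_and_grad and optimizer.apply_gradients follow the MLX API this example used at the time (newer MLX replaces apply_gradients with optimizer.update):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class TinyLM(nn.Module):
    # hypothetical stand-in for the example's TransformerLM
    def __init__(self, vocab_size=64, dims=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        return self.out_proj(self.embedding(x))

    def loss(self, x, y, reduce=True):
        losses = nn.losses.cross_entropy(self(x), y)
        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))

model = TinyLM()
optimizer = optim.SGD(learning_rate=0.1)
loss_and_grad_fn = nn.value_and_grad(model, model.loss)

inputs = mx.zeros((2, 8), dtype=mx.int32)    # dummy token batch
targets = mx.zeros((2, 8), dtype=mx.int32)
loss, grads = loss_and_grad_fn(inputs, targets)
model.update(optimizer.apply_gradients(grads, model))
mx.eval(loss, model.parameters())            # a single eval per step suffices
print(loss.item())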