mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
remove simplify (#379)
This commit is contained in:
parent
0b57f0eae6
commit
5aa652d3c2
@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
|||||||
shards = []
|
shards = []
|
||||||
shard, shard_size = {}, 0
|
shard, shard_size = {}, 0
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
# TODO: simplify to v.nbytes as soon as mx.array exposes it
|
if shard_size + v.nbytes > max_file_size_bytes:
|
||||||
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
|
|
||||||
if shard_size + estimated_size > max_file_size_bytes:
|
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
shard, shard_size = {}, 0
|
shard, shard_size = {}, 0
|
||||||
shard[k] = v
|
shard[k] = v
|
||||||
shard_size += estimated_size
|
shard_size += v.nbytes
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
return shards
|
return shards
|
||||||
|
|
||||||
|
@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
|
|||||||
shards = []
|
shards = []
|
||||||
shard, shard_size = {}, 0
|
shard, shard_size = {}, 0
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
estimated_size = v.size * v.dtype.size
|
if shard_size + v.nbytes > max_file_size_bytes:
|
||||||
if shard_size + estimated_size > max_file_size_bytes:
|
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
shard, shard_size = {}, 0
|
shard, shard_size = {}, 0
|
||||||
shard[k] = v
|
shard[k] = v
|
||||||
shard_size += estimated_size
|
shard_size += v.nbytes
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
return shards
|
return shards
|
||||||
|
|
||||||
|
@ -102,12 +102,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
|||||||
shards = []
|
shards = []
|
||||||
shard, shard_size = {}, 0
|
shard, shard_size = {}, 0
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
estimated_size = v.size * v.dtype.size
|
if shard_size + v.nbytes > max_file_size_bytes:
|
||||||
if shard_size + estimated_size > max_file_size_bytes:
|
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
shard, shard_size = {}, 0
|
shard, shard_size = {}, 0
|
||||||
shard[k] = v
|
shard[k] = v
|
||||||
shard_size += estimated_size
|
shard_size += v.nbytes
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
return shards
|
return shards
|
||||||
|
|
||||||
|
@ -50,8 +50,6 @@ if __name__ == "__main__":
|
|||||||
negative_text=args.negative_prompt,
|
negative_text=args.negative_prompt,
|
||||||
)
|
)
|
||||||
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
|
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
|
||||||
mx.simplify(x_t)
|
|
||||||
mx.simplify(x_t)
|
|
||||||
mx.eval(x_t)
|
mx.eval(x_t)
|
||||||
|
|
||||||
# Decode them into images
|
# Decode them into images
|
||||||
|
@ -34,8 +34,6 @@ if __name__ == "__main__":
|
|||||||
negative_text=args.negative_prompt,
|
negative_text=args.negative_prompt,
|
||||||
)
|
)
|
||||||
for x_t in tqdm(latents, total=args.steps):
|
for x_t in tqdm(latents, total=args.steps):
|
||||||
mx.simplify(x_t)
|
|
||||||
mx.simplify(x_t)
|
|
||||||
mx.eval(x_t)
|
mx.eval(x_t)
|
||||||
|
|
||||||
# Decode them into images
|
# Decode them into images
|
||||||
|
@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
|
|||||||
def loss(self, x, y, reduce=True):
|
def loss(self, x, y, reduce=True):
|
||||||
logits = self(x)
|
logits = self(x)
|
||||||
losses = nn.losses.cross_entropy(logits, y)
|
losses = nn.losses.cross_entropy(logits, y)
|
||||||
mx.simplify(losses)
|
|
||||||
|
|
||||||
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
|
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
|
||||||
|
|
||||||
|
|
||||||
@ -96,7 +94,6 @@ def main(args):
|
|||||||
inputs, targets = map(mx.array, (inputs, targets))
|
inputs, targets = map(mx.array, (inputs, targets))
|
||||||
loss, grads = loss_and_grad_fn(inputs, targets)
|
loss, grads = loss_and_grad_fn(inputs, targets)
|
||||||
model.update(optimizer.apply_gradients(grads, model))
|
model.update(optimizer.apply_gradients(grads, model))
|
||||||
mx.simplify(loss, model.parameters())
|
|
||||||
mx.eval(loss, model.parameters())
|
mx.eval(loss, model.parameters())
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
if (it + 1) % steps_per_report == 0:
|
if (it + 1) % steps_per_report == 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user