mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
remove simplify (#379)
This commit is contained in:
parent
0b57f0eae6
commit
5aa652d3c2
@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
# TODO: simplify to v.nbytes as soon as mx.array exposes it
|
||||
estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
@ -102,12 +102,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shard_size += v.nbytes
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
@ -50,8 +50,6 @@ if __name__ == "__main__":
|
||||
negative_text=args.negative_prompt,
|
||||
)
|
||||
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
|
||||
mx.simplify(x_t)
|
||||
mx.simplify(x_t)
|
||||
mx.eval(x_t)
|
||||
|
||||
# Decode them into images
|
||||
|
@ -34,8 +34,6 @@ if __name__ == "__main__":
|
||||
negative_text=args.negative_prompt,
|
||||
)
|
||||
for x_t in tqdm(latents, total=args.steps):
|
||||
mx.simplify(x_t)
|
||||
mx.simplify(x_t)
|
||||
mx.eval(x_t)
|
||||
|
||||
# Decode them into images
|
||||
|
@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
|
||||
def loss(self, x, y, reduce=True):
|
||||
logits = self(x)
|
||||
losses = nn.losses.cross_entropy(logits, y)
|
||||
mx.simplify(losses)
|
||||
|
||||
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
|
||||
|
||||
|
||||
@ -96,7 +94,6 @@ def main(args):
|
||||
inputs, targets = map(mx.array, (inputs, targets))
|
||||
loss, grads = loss_and_grad_fn(inputs, targets)
|
||||
model.update(optimizer.apply_gradients(grads, model))
|
||||
mx.simplify(loss, model.parameters())
|
||||
mx.eval(loss, model.parameters())
|
||||
losses.append(loss.item())
|
||||
if (it + 1) % steps_per_report == 0:
|
||||
|
Loading…
Reference in New Issue
Block a user