diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index d8f2c8e6..6c9dcea4 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -151,13 +151,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
-        # TODO: simplify to v.nbytes as soon as mx.array exposes it
-        estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes
-        if shard_size + estimated_size > max_file_size_bytes:
+        if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
             shard, shard_size = {}, 0
         shard[k] = v
-        shard_size += estimated_size
+        shard_size += v.nbytes
     shards.append(shard)
     return shards
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index e7d9a429..44f3abd4 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -311,12 +311,11 @@ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
-        estimated_size = v.size * v.dtype.size
-        if shard_size + estimated_size > max_file_size_bytes:
+        if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
             shard, shard_size = {}, 0
         shard[k] = v
-        shard_size += estimated_size
+        shard_size += v.nbytes
     shards.append(shard)
     return shards
 
diff --git a/lora/utils.py b/lora/utils.py
index dee35bc4..c76b097a 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -102,12 +102,11 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     shards = []
     shard, shard_size = {}, 0
     for k, v in weights.items():
-        estimated_size = v.size * v.dtype.size
-        if shard_size + estimated_size > max_file_size_bytes:
+        if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
             shard, shard_size = {}, 0
         shard[k] = v
-        shard_size += estimated_size
+        shard_size += v.nbytes
     shards.append(shard)
     return shards
 
diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py
index 7a5d5eb2..46b25c29 100644
--- a/stable_diffusion/image2image.py
+++ b/stable_diffusion/image2image.py
@@ -50,8 +50,6 @@ if __name__ == "__main__":
         negative_text=args.negative_prompt,
     )
     for x_t in tqdm(latents, total=int(args.steps * args.strength)):
-        mx.simplify(x_t)
-        mx.simplify(x_t)
         mx.eval(x_t)
 
     # Decode them into images
diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py
index 9c49e1d2..d5d72974 100644
--- a/stable_diffusion/txt2image.py
+++ b/stable_diffusion/txt2image.py
@@ -34,8 +34,6 @@ if __name__ == "__main__":
         negative_text=args.negative_prompt,
     )
     for x_t in tqdm(latents, total=args.steps):
-        mx.simplify(x_t)
-        mx.simplify(x_t)
         mx.eval(x_t)
 
     # Decode them into images
diff --git a/transformer_lm/main.py b/transformer_lm/main.py
index dea5a982..a546c874 100644
--- a/transformer_lm/main.py
+++ b/transformer_lm/main.py
@@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
     def loss(self, x, y, reduce=True):
         logits = self(x)
         losses = nn.losses.cross_entropy(logits, y)
-        mx.simplify(losses)
-
         return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
 
 
@@ -96,7 +94,6 @@ def main(args):
         inputs, targets = map(mx.array, (inputs, targets))
         loss, grads = loss_and_grad_fn(inputs, targets)
         model.update(optimizer.apply_gradients(grads, model))
-        mx.simplify(loss, model.parameters())
         mx.eval(loss, model.parameters())
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
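
For reference, below is a minimal standalone sketch (not part of the patch) of the sharding helper after this change. It takes the byte cap as a parameter directly rather than reproducing the GiB-to-bytes conversion from the surrounding files, which these hunks don't show, and it checks the identity that makes the simplification safe: the nbytes property now exposed by mx.array agrees with the old hand-computed v.size * v.dtype.size estimate. The toy weight names are hypothetical.

# Standalone sketch, assuming mlx is installed; mirrors the patched
# make_shards but takes the byte limit directly.
import mlx.core as mx

def make_shards(weights: dict, max_file_size_bytes: int) -> list:
    # Greedily pack tensors into shards, opening a new shard once the
    # next tensor would push the running total past the cap.
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards

# Hypothetical toy weights: 2 MiB of float16 and 4 MiB of float32.
weights = {
    "wq": mx.zeros((1024, 1024), dtype=mx.float16),
    "wk": mx.zeros((1024, 1024), dtype=mx.float32),
}
for v in weights.values():
    # The identity the old code computed by hand.
    assert v.nbytes == v.size * v.dtype.size

# With a 3 MiB cap, each tensor lands in its own shard.
print([list(s) for s in make_shards(weights, 3 << 20)])  # [['wq'], ['wk']]

The mx.simplify deletions in the remaining hunks are mechanical: the calls are dropped with no replacement, leaving the adjacent mx.eval calls to force evaluation on their own.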