Merge branch 'ml-explore:main' into fix-unsupported-scalartype

commit 85345d42cb
Author: Ricardo La Rosa, 2023-12-07 17:04:01 +01:00 (committed by GitHub)
8 changed files with 75 additions and 51 deletions

.gitignore

@@ -127,3 +127,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+.idea/
+.vscode/

llama/README.md

@@ -17,6 +17,9 @@ weights you will need to [request
 access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
 from Meta.
 
+Alternatively, you can also download select converted checkpoints from the
+[mlx-llama](https://huggingface.co/mlx-llama) community organisation on
+Hugging Face and skip the conversion step.
 
 Convert the weights with:
 ```

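The new README sentence above points to the mlx-llama community organisation on Hugging Face. As a purely illustrative sketch (the exact repository names are not part of this diff, and `mlx-llama/some-model` below is a placeholder), such a converted checkpoint could be fetched with the `huggingface_hub` client:

```python
from huggingface_hub import snapshot_download

# Fetch a converted checkpoint from the mlx-llama community organisation.
# "mlx-llama/some-model" is a placeholder repo id; browse
# https://huggingface.co/mlx-llama for the repositories that actually exist.
path = snapshot_download(repo_id="mlx-llama/some-model")
print(f"Checkpoint files downloaded to {path}")
```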
llama/requirements.txt

@@ -1,3 +1,4 @@
 mlx
 sentencepiece
 torch
+numpy

stable_diffusion/README.md

@@ -65,8 +65,9 @@ Performance
 -----------
 
 The following table compares the performance of the UNet in stable diffusion.
-We report throughput in images per second for the provided `txt2image.py`
-script and the `diffusers` library using the MPS PyTorch backend.
+We report throughput in images per second **processed by the UNet** for the
+provided `txt2image.py` script and the `diffusers` library using the MPS
+PyTorch backend.
 
 At the time of writing this comparison convolutions are still some of the least
 optimized operations in MLX. Despite that, MLX still achieves **~40% higher
@@ -93,3 +94,7 @@ The above experiments were made on an M2 Ultra with PyTorch version 2.1,
 diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
 used classifier free guidance which means that the above batch sizes result
 double the images processed by the UNet.
+
+Note that the above table means that it takes about 90 seconds to fully
+generate 16 images with MLX and 50 diffusion steps with classifier free
+guidance and about 120 for PyTorch.
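As a sanity check on the note added above: throughput is defined as images per second through the UNet, and classifier-free guidance doubles the images the UNet processes, so 16 images at 50 diffusion steps cost 2 × 16 × 50 = 1600 UNet image-evaluations. The sketch below back-solves the implied throughput from the quoted wall-clock times; the actual table values are not part of this hunk, so these are back-of-the-envelope figures only:

```python
# Back-of-the-envelope check of the 90 s (MLX) and 120 s (PyTorch) figures.
images, steps = 16, 50
cfg_factor = 2  # classifier-free guidance doubles the images seen by the UNet
unet_images = cfg_factor * images * steps  # 1600 image-evaluations in total

for backend, seconds in {"MLX": 90, "PyTorch (MPS)": 120}.items():
    print(f"{backend}: ~{unet_images / seconds:.1f} images/sec through the UNet")
```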

stable_diffusion/requirements.txt

@@ -1,3 +1,4 @@
+mlx
 safetensors
 huggingface-hub
 regex

transformer_lm/README.md

@@ -11,4 +11,4 @@ python main.py --gpu
 
 By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.
 
-To run the PyTorch, Jax or TensorFlowexamples install the respective framework.
+To run the PyTorch, Jax or TensorFlow examples install the respective framework.

transformer_lm/main.py

@@ -81,13 +81,13 @@ def main(args):
     optimizer = optim.SGD(learning_rate=args.learning_rate)
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)
 
-    def eval_fn(params, dataset):
+    def eval_fn(model, dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = self.loss(bx, by, reduce=False)
+            losses = model.loss(bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)
@@ -110,9 +110,8 @@ def main(args):
             )
             losses = []
             tic = time.perf_counter()
-
         if (it + 1) % steps_per_eval == 0:
-            val_loss = eval_fn(params, valid)
+            val_loss = eval_fn(model, valid)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: "
@@ -123,7 +122,7 @@ def main(args):
             tic = time.perf_counter()
 
     if args.eval_test:
-        test_loss = eval_fn(params, test)
+        test_loss = eval_fn(model, test)
         test_ppl = math.exp(test_loss)
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

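Note that the hunks above fix a real bug, not just a rename: inside a plain function `self` is undefined, so the old `losses = self.loss(...)` would raise `NameError`, and the `params` argument was never used. A minimal, MLX-free sketch of the before/after pattern (the `Model` class here is hypothetical, standing in for the example's transformer):

```python
# Minimal reproduction of the bug fixed above (no MLX required).
class Model:
    def loss(self, inputs, targets, reduce=True):
        errs = [(i - t) ** 2 for i, t in zip(inputs, targets)]
        return sum(errs) / len(errs) if reduce else errs

def eval_fn_broken(params, dataset):
    inputs, targets = dataset
    return self.loss(inputs, targets)  # NameError: `self` exists only in methods

def eval_fn_fixed(model, dataset):
    inputs, targets = dataset
    return model.loss(inputs, targets)  # pass the model in explicitly

print(eval_fn_fixed(Model(), ([1.0, 2.0], [1.5, 2.5])))  # 0.25
```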
whisper/benchmark.py

@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.
 
+import sys
 import time
 
 import mlx.core as mx
@@ -48,46 +49,58 @@ def everything():
 
 
 if __name__ == "__main__":
-    feat_time = timer(feats)
-    print(f"Feature time {feat_time:.3f}")
-    mels = feats()[None]
-    tokens = mx.array(
-        [
-            50364,
-            1396,
-            264,
-            665,
-            5133,
-            23109,
-            25462,
-            264,
-            6582,
-            293,
-            750,
-            632,
-            42841,
-            292,
-            370,
-            938,
-            294,
-            4054,
-            293,
-            12653,
-            356,
-            50620,
-            50620,
-            23563,
-            322,
-            3312,
-            13,
-            50680,
-        ],
-        mx.int32,
-    )[None]
-    model = load_models.load_model("tiny")
-    model_forward_time = timer(model_forward, model, mels, tokens)
-    print(f"Model forward time {model_forward_time:.3f}")
-    decode_time = timer(decode, model, mels)
-    print(f"Decode time {decode_time:.3f}")
-    everything_time = timer(everything)
-    print(f"Everything time {everything_time:.3f}")
+
+    # get command line arguments without 3rd party libraries;
+    # the command line argument to benchmark all models is "--all"
+    models = ["tiny"]
+    if len(sys.argv) > 1:
+        if sys.argv[1] == "--all":
+            models = ["tiny", "small", "medium", "large"]
+
+    for model_name in models:
+        feat_time = timer(feats)
+
+        print(f"\nModel: {model_name.upper()}")
+        print(f"\nFeature time {feat_time:.3f}")
+        mels = feats()[None]
+        tokens = mx.array(
+            [
+                50364,
+                1396,
+                264,
+                665,
+                5133,
+                23109,
+                25462,
+                264,
+                6582,
+                293,
+                750,
+                632,
+                42841,
+                292,
+                370,
+                938,
+                294,
+                4054,
+                293,
+                12653,
+                356,
+                50620,
+                50620,
+                23563,
+                322,
+                3312,
+                13,
+                50680,
+            ],
+            mx.int32,
+        )[None]
+        model = load_models.load_model(f"{model_name}")
+        model_forward_time = timer(model_forward, model, mels, tokens)
+        print(f"Model forward time {model_forward_time:.3f}")
+        decode_time = timer(decode, model, mels)
+        print(f"Decode time {decode_time:.3f}")
+        everything_time = timer(everything)
+        print(f"Everything time {everything_time:.3f}")
+        print(f"\n{'-----' * 10}\n")