7 Commits

Author SHA1 Message Date
Anthony
e52c128d11 Use model.safetensors with Whisper (#1399)
Some checks failed
Test / check_lint (push) Has been cancelled
2025-12-15 06:17:08 -08:00
Awni Hannun
7ddca42f4d switch to github actions (#1394)
Some checks failed
Test / check_lint (push) Has been cancelled
2025-11-20 09:57:43 -08:00
Armin Stross-Radschinski
21a4d4cdab Update whisper command line help mentioning --word-timestamps (#1390) 2025-10-07 11:19:46 -07:00
Awni Hannun
8e4391ca21 whisper nits (#1388) 2025-09-03 13:18:50 -07:00
Awni Hannun
c1af8c46bd version (#1387) 2025-08-29 08:03:52 -07:00
Awni Hannun
f143957a06 switch quantized and non-quantized to be correct (#1385) 2025-08-29 07:53:44 -07:00
Awni Hannun
cfc5d25acd fix temperature based sampling (#1386) 2025-08-29 07:53:37 -07:00
18 changed files with 79 additions and 1376 deletions

View File

@@ -1,40 +0,0 @@
# CircleCI pipeline: runs the pre-commit style checks on regular branches and,
# after a manual approval step, on pull-request refs.
version: 2.1
orbs:
# Orb that supplies the `apple/authenticate` job used by the `prb` workflow.
apple: ml-explore/pr-approval@0.1.0
jobs:
# Single job: install pre-commit, run every hook, and fail if the hooks modified any file.
linux_build_and_test:
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
workflows:
# Runs for branch names made only of word characters and hyphens that do not start with `pull/`.
build_and_test:
when:
matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- linux_build_and_test
# Runs for `pull/<number>` refs (optionally suffixed with `/head`); the build job waits on the manual `hold` approval.
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- linux_build_and_test:
requires: [ hold ]

25
.github/workflows/pull_request.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
# Lint workflow: runs the pre-commit hooks on pushes to main and on pull requests.
name: Test

on:
  push:
    branches: ["main"]
  pull_request:

permissions:
  contents: read

concurrency:
  # One concurrency group per workflow and ref, so a newer run replaces an older one.
  group: ${{ github.workflow }}-${{ github.ref }}
  # Cancel superseded runs everywhere except on main. Branch refs are spelled
  # `refs/heads/<name>`; comparing against `refs/head/main` never matches, which
  # makes the expression always true and cancels in-progress runs on main too.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
  check_lint:
    # Skip the job on forks.
    if: github.repository == 'ml-explore/mlx-examples'
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: actions/setup-python@v6
        with:
          python-version: "3.10"
      - uses: pre-commit/action@v3.0.1

Binary file not shown.

Before

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 142 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 79 KiB

View File

@@ -1,195 +0,0 @@
import mnist
import argparse
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
# Generator Block
def GenBlock(in_dim: int, out_dim: int):
    """Build one generator stage: Linear -> BatchNorm -> LeakyReLU(0.2).

    NOTE(review): ``0.8`` is passed as the second positional argument of
    ``nn.BatchNorm``. Confirm which parameter occupies that slot in the
    installed MLX version (it may be ``eps`` rather than ``momentum``).
    """
    layers = [
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm(out_dim, 0.8),
        nn.LeakyReLU(0.2),
    ]
    return nn.Sequential(*layers)
# Generator Model
class Generator(nn.Module):
    """MLP generator: three ``GenBlock`` stages, a linear projection, then tanh."""

    def __init__(self, z_dim: int = 32, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()
        # Layer widths grow z_dim -> hidden -> 2*hidden -> 4*hidden before the
        # final projection to the flattened image size.
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]
        stages = [GenBlock(a, b) for a, b in zip(widths[:-1], widths[1:])]
        self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], im_dim))

    def __call__(self, noise):
        """Map ``noise`` to generated samples squashed into [-1, 1] by tanh."""
        return mx.tanh(self.gen(noise))
# make 2D noise with shape n_samples x z_dim
def get_noise(n_samples: int, z_dim: int) -> mx.array:
    """Sample standard-normal noise of shape ``(n_samples, z_dim)``.

    Args:
        n_samples: Number of noise vectors (rows) to draw.
        z_dim: Length of each noise vector.

    Returns:
        An ``mx.array`` drawn from ``mx.random.normal``.
    """
    return mx.random.normal(shape=(n_samples, z_dim))
#---------------------------------------------#
# Discriminator Block
def DisBlock(in_dim: int, out_dim: int):
    """Build one discriminator stage: Linear -> LeakyReLU(0.2) -> Dropout(0.3)."""
    layers = [
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(negative_slope=0.2),
        nn.Dropout(0.3),
    ]
    return nn.Sequential(*layers)
# Discriminator Model
class Discriminator(nn.Module):
    """MLP discriminator: three ``DisBlock`` stages, a 1-unit linear layer, then sigmoid."""

    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()
        # Layer widths shrink im_dim -> 4*hidden -> 2*hidden -> hidden.
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        stages = [DisBlock(a, b) for a, b in zip(widths[:-1], widths[1:])]
        self.disc = nn.Sequential(
            *stages,
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def __call__(self, noise):
        """Score the input batch; the output has already passed through sigmoid.

        The parameter is named ``noise`` for interface compatibility, but the
        training code passes image batches (real or generated) here.
        """
        scores = self.disc(noise)
        return scores
# Discriminator Loss
def disc_loss(gen, disc, real, num_images, z_dim):
    """Discriminator loss: mean of the BCE on generated and on real samples.

    Args:
        gen: Generator; called on freshly sampled noise.
        disc: Discriminator; its output already passed through a sigmoid.
        real: Batch of real samples fed to ``disc``.
        num_images: Number of fake samples to generate.
        z_dim: Length of each noise vector.

    Returns:
        Scalar loss ``(fake_loss + real_loss) / 2``.
    """
    fake_images = gen(get_noise(num_images, z_dim))

    # Generated samples should be classified as 0. The discriminator ends in a
    # sigmoid, so its outputs are probabilities: with_logits must be False or
    # the sigmoid would be applied a second time inside the loss.
    fake_pred = disc(fake_images)
    fake_labels = mx.zeros((fake_images.shape[0], 1))
    fake_loss = nn.losses.binary_cross_entropy(
        fake_pred, fake_labels, with_logits=False, reduction="mean"
    )

    # Real samples should be classified as 1.
    real_pred = disc(real)
    real_labels = mx.ones((real.shape[0], 1))
    real_loss = nn.losses.binary_cross_entropy(
        real_pred, real_labels, with_logits=False, reduction="mean"
    )

    return (fake_loss + real_loss) / 2.0
# Generator Loss
def gen_loss(gen, disc, num_images, z_dim):
    """Generator loss: BCE pushing the discriminator to label fakes as real.

    Args:
        gen: Generator; called on freshly sampled noise.
        disc: Discriminator; its output already passed through a sigmoid.
        num_images: Number of fake samples to generate.
        z_dim: Length of each noise vector.

    Returns:
        Scalar mean binary cross-entropy against all-ones targets.
    """
    fake_images = gen(get_noise(num_images, z_dim))
    fake_pred = disc(fake_images)
    # The generator wants its samples classified as 1.
    fake_labels = mx.ones((fake_images.shape[0], 1))
    # The discriminator outputs probabilities (it ends in a sigmoid), so the
    # loss must not treat them as logits and apply a second sigmoid.
    return nn.losses.binary_cross_entropy(
        fake_pred, fake_labels, with_logits=False, reduction="mean"
    )
# make batch of images
def batch_iterate(batch_size: int, ipt: np.ndarray):
    """Yield the rows of ``ipt`` in shuffled mini-batches.

    A fresh permutation is drawn from NumPy's global RNG on every call. The
    final batch is shorter when ``len(ipt)`` is not a multiple of
    ``batch_size``, and nothing is yielded for an empty input.

    Args:
        batch_size: Maximum number of rows per batch.
        ipt: Array supporting integer-array indexing (a plain list does not).

    Yields:
        ``ipt[ids]`` for consecutive slices ``ids`` of the permutation.
    """
    perm = np.random.permutation(len(ipt))
    for s in range(0, len(ipt), batch_size):
        yield ipt[perm[s : s + batch_size]]
# plot batch of images at epoch steps
def show_images(epoch_num: int, imgs, num_imgs: int = 25):
    """Plot up to ``num_imgs`` generated samples and save the grid as a PNG.

    The figure is written to ``gen_images/img_<epoch_num>.png`` (the directory
    is created when missing) and then displayed with ``plt.show()``.

    Args:
        epoch_num: Epoch index used in the output file name.
        imgs: Batch of flattened samples; each row is reshaped to 28x28.
        num_imgs: Maximum number of samples to draw (default 25, a 5x5 grid).
    """
    import math
    import os

    # Never index past the end of the batch; do nothing if there is nothing to draw.
    n_show = min(num_imgs, imgs.shape[0])
    if n_show <= 0:
        return

    # Smallest square grid that fits every image (5x5 for the default 25).
    side = math.ceil(math.sqrt(n_show))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i < n_show:
            img = mx.array(imgs[i]).reshape(28, 28)
            ax.imshow(img, cmap='gray')
    plt.tight_layout()

    # savefig does not create missing parent directories.
    os.makedirs('gen_images', exist_ok=True)
    plt.savefig('gen_images/img_{}.png'.format(epoch_num))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def main(args: argparse.Namespace):
    """Train the GAN and periodically save a grid of generated samples.

    Args:
        args: Parsed command-line arguments. ``args.dataset`` names the loader
            function in the ``mnist`` module; ``"mnist"`` is used when the
            attribute is absent.
    """
    seed = 42
    n_epochs = 500
    z_dim = 128
    batch_size = 128
    lr = 2e-5

    mx.random.seed(seed)

    # Load the dataset selected on the command line (previously the loader
    # name was hard-coded, so --dataset had no effect).
    loader = getattr(mnist, getattr(args, "dataset", "mnist"))
    train_images, *_ = map(np.array, loader())

    # Rescale images to [-1, 1] to match the generator's tanh output range.
    train_images = train_images * 2.0 - 1.0

    gen = Generator(z_dim)
    mx.eval(gen.parameters())
    gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    disc = Discriminator()
    mx.eval(disc.parameters())
    disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    # Each value_and_grad differentiates with respect to the model passed as
    # its first argument.
    D_loss_grad = nn.value_and_grad(disc, disc_loss)
    G_loss_grad = nn.value_and_grad(gen, gen_loss)

    for epoch in tqdm(range(n_epochs)):
        for idx, real in enumerate(batch_iterate(batch_size, train_images)):
            # Discriminator step.
            D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)
            disc_opt.update(disc, D_grads)
            # Force evaluation of the updated parameters and optimizer state.
            mx.eval(disc.parameters(), disc_opt.state)

            # Generator step.
            G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
            gen_opt.update(gen, G_grads)
            mx.eval(gen.parameters(), gen_opt.state)

        # NOTE(review): reporting is assumed to run once per epoch, after the
        # batch loop (using the last batch's losses) — confirm the nesting.
        if epoch % 100 == 0:
            print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch, idx, D_loss, G_loss))
            fake_noise = get_noise(batch_size, z_dim)
            fake = gen(fake_noise)
            show_images(epoch, fake)
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`, not the
    # description, so the text must be passed by keyword to appear as the
    # description in --help instead of replacing the program name.
    parser = argparse.ArgumentParser(
        description="Train a simple GAN on MNIST with MLX."
    )
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="The dataset to use.",
    )
    args = parser.parse_args()
    # Fall back to the CPU device unless --gpu was given.
    if not args.gpu:
        mx.set_default_device(mx.cpu)
    main(args)

View File

@@ -1,83 +0,0 @@
# Copyright © 2023 Apple Inc.
import gzip
import os
import pickle
from urllib import request
import numpy as np
def mnist(
    save_dir="/tmp",
    base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
    filename="mnist.pkl",
):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for already downloaded data otherwise downloads.
    The raw archives and the pickled cache are both stored in `save_dir`,
    which is created when missing.

    Download code modified from:
      https://github.com/hsjeong5/MNIST-for-Numpy

    Args:
        save_dir: Directory holding the downloaded archives and the cache.
        base_url: URL prefix the four archive names are appended to.
        filename: Name of the pickled cache file inside `save_dir`.

    Returns:
        Tuple `(train_images, train_labels, test_images, test_labels)`:
        images are float32 in [0, 1] with 784 columns, labels are uint32.
    """

    # (dictionary key, archive name); the two image archives come first.
    archives = [
        ("training_images", "train-images-idx3-ubyte.gz"),
        ("test_images", "t10k-images-idx3-ubyte.gz"),
        ("training_labels", "train-labels-idx1-ubyte.gz"),
        ("test_labels", "t10k-labels-idx1-ubyte.gz"),
    ]

    def download_and_save(save_file):
        os.makedirs(save_dir, exist_ok=True)
        # Download into save_dir rather than a hard-coded /tmp so the
        # caller's directory choice is honoured for every file.
        for _, archive in archives:
            request.urlretrieve(base_url + archive, os.path.join(save_dir, archive))

        data = {}
        # Image archives: skip the 16-byte header, one 28*28 row per image.
        for key, archive in archives[:2]:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                data[key] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        # Label archives: skip the 8-byte header, one byte per label.
        for key, archive in archives[-2:]:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                data[key] = np.frombuffer(f.read(), np.uint8, offset=8)

        with open(save_file, "wb") as f:
            pickle.dump(data, f)

    save_file = os.path.join(save_dir, filename)
    if not os.path.exists(save_file):
        download_and_save(save_file)

    # NOTE(security): pickle.load can execute arbitrary code. The cache is
    # expected to have been written by download_and_save above; do not point
    # save_dir/filename at files from an untrusted source.
    with open(save_file, "rb") as f:
        data = pickle.load(f)

    def preproc(x):
        # Scale uint8 pixel values to float32 in [0, 1].
        return x.astype(np.float32) / 255.0

    return (
        preproc(data["training_images"]),
        data["training_labels"].astype(np.uint32),
        preproc(data["test_images"]),
        data["test_labels"].astype(np.uint32),
    )
def fashion_mnist(save_dir="/tmp"):
    """Load Fashion-MNIST through ``mnist`` using its own URL and cache file name."""
    options = {
        "base_url": "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
        "filename": "fashion_mnist.pkl",
    }
    return mnist(save_dir, **options)
if __name__ == "__main__":
    # Smoke test: load MNIST and verify the shape of each of the four arrays.
    train_x, train_y, test_x, test_y = mnist()
    expectations = [
        (train_x, (60000, 28 * 28), "Wrong training set size"),
        (train_y, (60000,), "Wrong training set size"),
        (test_x, (10000, 28 * 28), "Wrong test set size"),
        (test_y, (10000,), "Wrong test set size"),
    ]
    for array, shape, message in expectations:
        assert array.shape == shape, message

File diff suppressed because one or more lines are too long

View File

@@ -11,12 +11,6 @@ audio_file = "mlx_whisper/assets/ls_test.flac"
def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark script.")
parser.add_argument(
"--mlx-dir",
type=str,
default="mlx_models",
help="The folder of MLX models",
)
parser.add_argument(
"--all",
action="store_true",

View File

@@ -382,7 +382,7 @@ if __name__ == "__main__":
# Save weights
print("[INFO] Saving")
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
# Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f:

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.4.1"
__version__ = "0.4.3"

View File

@@ -156,42 +156,42 @@ def build_parser():
"--prepend-punctuations",
type=str,
default="\"'“¿([{-",
help="If word-timestamps is True, merge these punctuation symbols with the next word",
help="If --word-timestamps is True, merge these punctuation symbols with the next word",
)
parser.add_argument(
"--append-punctuations",
type=str,
default="\"'.。,!?::”)]}、",
help="If word_timestamps is True, merge these punctuation symbols with the previous word",
help="If --word-timestamps is True, merge these punctuation symbols with the previous word",
)
parser.add_argument(
"--highlight-words",
type=str2bool,
default=False,
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
)
parser.add_argument(
"--max-line-width",
type=int,
default=None,
help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
help="(requires --word-timestamps True) the maximum number of characters in a line before breaking the line",
)
parser.add_argument(
"--max-line-count",
type=int,
default=None,
help="(requires --word_timestamps True) the maximum number of lines in a segment",
help="(requires --word-timestamps True) the maximum number of lines in a segment",
)
parser.add_argument(
"--max-words-per-line",
type=int,
default=None,
help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
)
parser.add_argument(
"--hallucination-silence-threshold",
type=optional_float,
help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
)
parser.add_argument(
"--clip-timestamps",

View File

@@ -265,7 +265,7 @@ class GreedyDecoder(TokenDecoder):
else:
next_tokens = categorical(logits, self.temperature)
logprobs = logits - mx.logsumexp(logits, axis=-1)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
@@ -380,7 +380,7 @@ class ApplyTimestampRules(LogitFilter):
# if sum of probability over timestamps is above any other token, sample timestamp
mask = mx.array(mask)
logprobs = logits - mx.logsumexp(logits, axis=-1)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
axis=-1, keepdims=True
)
@@ -603,6 +603,7 @@ class DecodingTask:
inputs = tokens[:, -1:]
if tokens.shape[-1] > self.n_ctx:
break
next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs
)
@@ -643,9 +644,7 @@ class DecodingTask:
tokens = mx.broadcast_to(
tokens, [n_audio, self.n_group, len(self.initial_tokens)]
)
tokens = tokens.reshape(
tokens, (n_audio * self.n_group, len(self.initial_tokens))
)
tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens)))
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

View File

@@ -26,7 +26,10 @@ def load_model(
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
wf = model_path / "model.safetensors"
if not wf.exists():
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))

View File

@@ -62,7 +62,7 @@ class ModelHolder:
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
path_or_hf_repo: str = "mlx-community/whisper-tiny",
path_or_hf_repo: str = "mlx-community/whisper-turbo",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,

View File

@@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "a9f4b67f",
"metadata": {},
"outputs": [],
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "5a45bf5a",
"metadata": {},
"outputs": [],
@@ -58,30 +58,10 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "51bd2ed4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fetching 7 files: 100%|███████████████████████| 7/7 [00:00<00:00, 120328.39it/s]\n",
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivot\n",
"==========\n",
"Prompt: 12 tokens, 78.111 tokens-per-sec\n",
"Generation: 100 tokens, 32.263 tokens-per-sec\n",
"Peak memory: 4.138 GB\n"
]
}
],
"outputs": [],
"source": [
"!mlx_lm.generate --model \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\" \\\n",
" --prompt \"Write a quick sort in Swift\""
@@ -97,60 +77,10 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "f7add212",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fetching 7 files: 100%|███████████████████████| 7/7 [00:00<00:00, 100205.22it/s]\n",
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivotIndex - 1)\n",
" quickSort(&array, pivotIndex + 1, high)\n",
" }\n",
"}\n",
"\n",
"func partition(_ array: inout [Int], _ low: Int, _ high: Int) -> Int {\n",
" let pivot = array[high]\n",
" var i = low\n",
" for j in low..<high {\n",
" if array[j] < pivot {\n",
" swapAt(&array, i, j)\n",
" i += 1\n",
" }\n",
" }\n",
" swapAt(&array, i, high)\n",
" return i\n",
"}\n",
"\n",
"func swapAt(_ array: inout [Int], _ i: Int, _ j: Int) {\n",
" let temp = array[i]\n",
" array[i] = array[j]\n",
" array[j] = temp\n",
"}\n",
"\n",
"// Example usage:\n",
"var arr = [3,6,8,5,2,1,9,7,4]\n",
"quickSort(&arr, 0, arr.count - 1)\n",
"print(arr) // Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
"```\n",
"\n",
"This code sorts an array of integers in ascending order using the QuickSort algorithm. The `quickSort` function takes an array, a starting index, and an ending index, and recursively sorts the subarrays on either side of the pivot element. The `partition` function finds the pivot index, and the `swapAt` function swaps two elements at given indices.\n",
"==========\n",
"Prompt: 12 tokens, 79.511 tokens-per-sec\n",
"Generation: 448 tokens, 31.514 tokens-per-sec\n",
"Peak memory: 4.184 GB\n"
]
}
],
"outputs": [],
"source": [
"!mlx_lm.generate --model \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\" \\\n",
" --prompt \"Write a quick sort in Swift\" \\\n",
@@ -170,62 +100,10 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "e042a321",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf86d861ac5e4879a194bfbc3f0e908d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivotIndex - 1)\n",
" quickSort(&array, pivotIndex + 1, high)\n",
" }\n",
"}\n",
"\n",
"func partition(_ array: inout [Int], _ low: Int, _ high: Int) -> Int {\n",
" let pivot = array[high]\n",
" var i = low\n",
" for j in low..<high {\n",
" if array[j] < pivot {\n",
" swapAt(&array, i, j)\n",
" i += 1\n",
" }\n",
" }\n",
" swapAt(&array, i, high)\n",
" return i\n",
"}\n",
"\n",
"func swapAt(_ array: inout [Int], _ i: Int, _ j: Int) {\n",
" let temp\n",
"==========\n",
"Prompt: 12 tokens, 78.600 tokens-per-sec\n",
"Generation: 256 tokens, 31.893 tokens-per-sec\n",
"Peak memory: 4.184 GB\n"
]
}
],
"outputs": [],
"source": [
"# Using MLX LM from Python\n",
"\n",
@@ -255,25 +133,10 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "629dfa50",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6602e0adecbe4b58ba99e514fc9c9032",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"from mlx_lm import load, generate\n",
"\n",
@@ -302,24 +165,10 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "a3b56bdc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Attention(\n",
" (q_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)\n",
" (k_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)\n",
" (v_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)\n",
" (o_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)\n",
" (rope): RoPE(128, traditional=False)\n",
")\n"
]
}
],
"outputs": [],
"source": [
"print(model.layers[0].self_attn)"
]
@@ -334,62 +183,10 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "775fd3f3",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed3f0f097da64b379819f577a29dc9f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivotIndex - 1)\n",
" quickSort(&array, pivotIndex + 1, high)\n",
" }\n",
"}\n",
"\n",
"func partition(_ array: inout [Int], _ low: Int, _ high: Int) -> Int {\n",
" let pivot = array[high]\n",
" var i = low\n",
" for j in low..<high {\n",
" if array[j] < pivot {\n",
" swapAt(&array, i, j)\n",
" i += 1\n",
" }\n",
" }\n",
" swapAt(&array, i, high)\n",
" return i\n",
"}\n",
"\n",
"func swapAt(_ array: inout [Int], _ i: Int, _ j: Int) {\n",
" let temp\n",
"==========\n",
"Prompt: 12 tokens, 76.085 tokens-per-sec\n",
"Generation: 256 tokens, 31.792 tokens-per-sec\n",
"Peak memory: 8.155 GB\n"
]
}
],
"outputs": [],
"source": [
"from mlx_lm import load, generate\n",
"from mlx_lm.models.cache import make_prompt_cache\n",
@@ -420,30 +217,10 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "0d669073",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Imagine you have a big box full of toys. You want to sort them so that all the red toys are together, all the blue toys are together, and all the green toys are together.\n",
"\n",
"1. First, you pick one toy (this is your pivot toy).\n",
"2. Then, you look at all the other toys one by one. If a toy is not red, you move it to the left if it's not red, and if it's blue, you move it to the right. You keep doing this until you have looked at all the toys.\n",
"3. Now, you have a group of toys on the left that are red or blue, and a group of toys on the right that are green or blue. You swap the pivot toy with one of the toys in the group on the left or right, depending on whether you want red toys on the left or right.\n",
"4. Now, you repeat the same process with the group of toys on the left and the group of toys on the right, until all the toys are sorted!\n",
"\n",
"This is a quick way to sort a big box of toys, and it's called QuickSort!\n",
"==========\n",
"Prompt: 16 tokens, 116.542 tokens-per-sec\n",
"Generation: 245 tokens, 29.632 tokens-per-sec\n",
"Peak memory: 8.155 GB\n"
]
}
],
"outputs": [],
"source": [
"prompt = \"how can I explain it to a five year old?\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
@@ -468,22 +245,10 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "f8218994",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Loading\n",
"Fetching 9 files: 100%|███████████████████████| 9/9 [00:00<00:00, 161319.38it/s]\n",
"[INFO] Using dtype: float16\n",
"[INFO] Quantizing\n",
"[INFO] Quantized model with 4.500 bits per weight.\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"mlx_path=\"./mistral-7b-v0.3-4bit\"\n",
@@ -496,24 +261,10 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "d4e62b96",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size of original bfloat16 model\n",
"===============================\n",
"3.8174 GB\n",
"\n",
"Size of quantized model\n",
"===============================\n",
"13.5049 GB\n"
]
}
],
"outputs": [],
"source": [
"import subprocess\n",
"\n",
@@ -523,12 +274,13 @@
" size_gb = size_mb / 1024\n",
" return size_gb\n",
"\n",
"directory_path = './mistral-7b-v0.3-4bit'\n",
"\n",
"directory_path = os.path.expanduser('~/.cache/huggingface/hub/models--mlx-community--Mistral-7B-Instruct-v0.3')\n",
"print(\"Size of original bfloat16 model\")\n",
"print(\"===============================\")\n",
"print(f\"{get_directory_size_mb(directory_path):2.4f} GB\")\n",
"print()\n",
"directory_path = os.path.expanduser('~/.cache/huggingface/hub/models--mlx-community--Mistral-7B-Instruct-v0.3')\n",
"directory_path = './mistral-7b-v0.3-4bit'\n",
"print(\"Size of quantized model\")\n",
"print(\"===============================\")\n",
"print(f\"{get_directory_size_mb(directory_path):2.4f} GB\")"
@@ -544,79 +296,10 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "9d2cd325",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Loading\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "15c0aac6b06b4541ab1d5d20f5c5a255",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 12 files: 0%| | 0/12 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba81f7892189428cadcfed19173a0731",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.safetensors: 42%|####1 | 10.5G/25.0G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Using dtype: bfloat16\n",
"[INFO] Quantizing\n",
"[INFO] Quantized model with 4.574 bits per weight.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "60442f8dcab848ef9771a8e2b5516a13",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0%| | 0.00/7.82k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Size of mixed 4-6-bit quantized model\n",
"============================\n",
"3.8799 GB\n"
]
}
],
"outputs": [],
"source": [
"# Model quantization with MLX LM in Python\n",
"\n",
@@ -658,23 +341,10 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "5efb794d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"The latest Super Bowl, Super Bowl LV (55), was played on February 7, 2021, between the Kansas City Chiefs and the Tampa Bay Buccaneers. The Tampa Bay Buccaneers, led by quarterback Tom Brady, won the game, making it his seventh Super Bowl victory. This made Tom Brady the most successful quarterback in Super Bowl history.\n",
"==========\n",
"Prompt: 11 tokens, 8.131 tokens-per-sec\n",
"Generation: 87 tokens, 31.385 tokens-per-sec\n",
"Peak memory: 4.137 GB\n"
]
}
],
"outputs": [],
"source": [
"!mlx_lm.generate --model \"./mistral-7b-v0.3-4bit\" \\\n",
" --prompt \"Who played in the latest super bowl?\""
@@ -690,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "b4c31126",
"metadata": {},
"outputs": [],
@@ -713,23 +383,10 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "7dcf9874",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"In the latest Super Bowl, the Philadelphia Eagles soared to victory, claiming their championship title with a resounding 40-22 win over the Kansas City Chiefs. The Eagles' triumphant flight was led by their fearless leader, Jalen Hurts, who not only secured his place in the annals of Super Bowl history but also etched his name into the hearts of Eagles fans everywhere. This wasn't just any Super Bowl; it was Super Bowl\n",
"==========\n",
"Prompt: 11 tokens, 28.533 tokens-per-sec\n",
"Generation: 100 tokens, 30.986 tokens-per-sec\n",
"Peak memory: 4.151 GB\n"
]
}
],
"outputs": [],
"source": [
"!mlx_lm.generate --model \"./mistral-7b-v0.3-4bit\" \\\n",
" --prompt \"Who played in the latest super bowl?\" \\\n",
@@ -746,18 +403,10 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "8935f7b6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n"
]
}
],
"outputs": [],
"source": [
"!mlx_lm.fuse --model \"./mistral-7b-v0.3-4bit\" \\\n",
" --adapter-path \"adapters\" \\\n",
@@ -774,23 +423,10 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"id": "343a8977",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"The latest Super Bowl, Super Bowl LIX, was played between the Philadelphia Eagles and the Kansas City Chiefs. The Philadelphia Eagles emerged victorious, with Jalen Hurts leading the charge for the Eagles.\n",
"==========\n",
"Prompt: 11 tokens, 11.760 tokens-per-sec\n",
"Generation: 46 tokens, 32.194 tokens-per-sec\n",
"Peak memory: 4.137 GB\n"
]
}
],
"outputs": [],
"source": [
"!mlx_lm.generate --model \"./fused-mistral-7b-v0.3-4bit\" \\\n",
" --prompt \"Who played in the latest super bowl?\" \\\n",
@@ -800,7 +436,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "mlx",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -814,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
"version": "3.9.17"
}
},
"nbformat": 4,