39 Commits

Author SHA1 Message Date
Huss
43ff302638 Merge a5752be9d9 into 4b2a0df237 2025-06-11 17:49:12 +03:00
Shubbair
a5752be9d9 Code Arrangement 2024-08-01 15:41:21 +03:00
Shubbair
f84b231cf2 Code Arrangement 2024-08-01 15:29:43 +03:00
Shubbair
7e0bdacef3 Code Arrangement 2024-08-01 15:22:19 +03:00
Shubbair
37bbf3ec54 Updating GAN Code... 2024-08-01 01:04:14 +03:00
Shubbair
4d17f80efb Updating GAN Code... 2024-07-31 20:23:57 +03:00
Shubbair
1ef3ad2c6c Updating GAN Code... 2024-07-31 19:59:36 +03:00
Shubbair
a8ffa9cb18 Updating GAN Code... 2024-07-31 11:50:32 +03:00
Shubbair
f70cef9567 Updating GAN Code... 2024-07-31 11:25:39 +03:00
Shubbair
6f7a6609b9 Updating MLX Notebook 2024-07-30 20:01:14 +03:00
Shubbair
0644cc101b Updating MLX Notebook 2024-07-30 19:50:02 +03:00
Shubbair
ad2b6643c0 Updating GAN Code... 2024-07-30 16:59:35 +03:00
Shubbair
3bea855bd2 Updating GAN Code... 2024-07-30 13:45:09 +03:00
Shubbair
c2d731d8a3 Updating GAN Code... 2024-07-30 13:24:53 +03:00
Shubbair
ba52447385 Updating GAN Code... 2024-07-30 13:21:38 +03:00
Shubbair
1e386b5c20 Updating GAN Code... 2024-07-30 02:56:13 +03:00
Shubbair
7438b54ecd Updating GAN Code... 2024-07-30 02:44:41 +03:00
Shubbair
7fea34d65e Updating GAN Code... 2024-07-30 02:37:09 +03:00
Shubbair
f505fe6e55 Updating GAN Code... 2024-07-30 02:17:12 +03:00
Shubbair
4e80759b39 Updating GAN Code... 2024-07-30 02:06:52 +03:00
Shubbair
306e53c402 Updating GAN Code... 2024-07-29 19:44:16 +03:00
Shubbair
bacaa9ec0e Updating GAN Code... 2024-07-29 01:30:08 +03:00
Shubbair
8d27be1442 Updating GAN Code... 2024-07-29 01:24:50 +03:00
Shubbair
4de0583b49 Updating GAN Code... 2024-07-28 19:18:35 +03:00
Shubbair
a07ef6d03b Updating GAN Code... 2024-07-28 18:11:39 +03:00
Shubbair
c0c8293842 Updating GAN Code... 2024-07-28 17:56:26 +03:00
Shubbair
d17d293df9 Updating GAN Code... 2024-07-28 17:35:36 +03:00
Shubbair
3e63cd93fe Updating GAN Code... 2024-07-28 17:26:24 +03:00
Shubbair
3716501e8d Updating GAN Code... 2024-07-28 17:22:40 +03:00
Shubbair
88a20b7276 Updating GAN Code... 2024-07-28 01:10:19 +03:00
Shubbair
8b1713737a Updating GAN Code... 2024-07-27 01:20:00 +03:00
Shubbair
f8b7094fb8 Updating GAN Code... 2024-07-27 01:19:50 +03:00
Shubbair
147cb3d2bc Updating GAN Code... 2024-07-27 01:09:51 +03:00
Shubbair
a05608c34d Updating GAN Code... 2024-07-27 00:22:29 +03:00
Shubbair
f176cce74d Updating GAN Code... 2024-07-27 00:19:08 +03:00
Shubbair
959c623908 Updating GAN Code... 2024-07-26 16:38:55 +03:00
Shubbair
591074bea8 Updating GAN Code... 2024-07-26 16:36:29 +03:00
Shubbair
d426586b03 Updating GAN Code... 2024-07-26 16:07:40 +03:00
Shubbair
5e7ce1048c Add GAN model 25/7 2024-07-25 21:00:41 +03:00
18 changed files with 1376 additions and 79 deletions

40
.circleci/config.yml Normal file
View File

@@ -0,0 +1,40 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
jobs:
linux_build_and_test:
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
workflows:
build_and_test:
when:
matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- linux_build_and_test
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- linux_build_and_test:
requires: [ hold ]

View File

@@ -1,25 +0,0 @@
name: Test
on:
push:
branches: ["main"]
pull_request:
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
jobs:
check_lint:
if: github.repository == 'ml-explore/mlx-examples'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/setup-python@v6
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.1

BIN
gan/gen_images/img_0.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 208 KiB

BIN
gan/gen_images/img_100.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 142 KiB

BIN
gan/gen_images/img_200.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

BIN
gan/gen_images/img_300.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

BIN
gan/gen_images/img_400.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

195
gan/main.py Normal file
View File

@@ -0,0 +1,195 @@
import mnist
import argparse
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
# Generator Block
def GenBlock(in_dim:int,out_dim:int):
return nn.Sequential(
nn.Linear(in_dim,out_dim),
nn.BatchNorm(out_dim, 0.8),
nn.LeakyReLU(0.2)
)
# Generator Model
class Generator(nn.Module):
def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):
super(Generator, self).__init__()
self.gen = nn.Sequential(
GenBlock(z_dim, hidden_dim),
GenBlock(hidden_dim, hidden_dim * 2),
GenBlock(hidden_dim * 2, hidden_dim * 4),
nn.Linear(hidden_dim * 4,im_dim),
)
def __call__(self, noise):
x = self.gen(noise)
return mx.tanh(x)
# make 2D noise with shape n_samples x z_dim
def get_noise(n_samples:list[int], z_dim:int)->list[int]:
return mx.random.normal(shape=(n_samples, z_dim))
#---------------------------------------------#
# Discriminator Block
def DisBlock(in_dim:int,out_dim:int):
return nn.Sequential(
nn.Linear(in_dim,out_dim),
nn.LeakyReLU(negative_slope=0.2),
nn.Dropout(0.3),
)
# Discriminator Model
class Discriminator(nn.Module):
def __init__(self,im_dim:int = 784, hidden_dim:int = 256):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
DisBlock(im_dim, hidden_dim * 4),
DisBlock(hidden_dim * 4, hidden_dim * 2),
DisBlock(hidden_dim * 2, hidden_dim),
nn.Linear(hidden_dim,1),
nn.Sigmoid()
)
def __call__(self, noise):
return self.disc(noise)
# Discriminator Loss
def disc_loss(gen, disc, real, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = disc(fake_images)
fake_labels = mx.zeros((fake_images.shape[0],1))
fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))
real_disc = mx.array(disc(real))
real_labels = mx.ones((real.shape[0],1))
real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))
disc_loss = (fake_loss + real_loss) / 2.0
return disc_loss
# Genearator Loss
def gen_loss(gen, disc, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = mx.array(disc(fake_images))
fake_labels = mx.ones((fake_images.shape[0],1))
gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)
return mx.mean(gen_loss)
# make batch of images
def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
perm = np.random.permutation(len(ipt))
for s in range(0, len(ipt), batch_size):
ids = perm[s : s + batch_size]
yield ipt[ids]
# plot batch of images at epoch steps
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
if (imgs.shape[0] > 0):
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
for i, ax in enumerate(axes.flat):
img = mx.array(imgs[i]).reshape(28,28)
ax.imshow(img,cmap='gray')
ax.axis('off')
plt.tight_layout()
plt.savefig('gen_images/img_{}.png'.format(epoch_num))
plt.show()
def main(args:dict):
seed = 42
n_epochs = 500
z_dim = 128
batch_size = 128
lr = 2e-5
mx.random.seed(seed)
# Load the data
train_images,*_ = map(np.array, getattr(mnist,'mnist')())
# Normalization images => [-1,1]
train_images = train_images * 2.0 - 1.0
gen = Generator(z_dim)
mx.eval(gen.parameters())
gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
disc = Discriminator()
mx.eval(disc.parameters())
disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
# TODO training...
D_loss_grad = nn.value_and_grad(disc, disc_loss)
G_loss_grad = nn.value_and_grad(gen, gen_loss)
for epoch in tqdm(range(n_epochs)):
for idx,real in enumerate(batch_iterate(batch_size, train_images)):
# TODO Train Discriminator
D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)
# Update optimizer
disc_opt.update(disc, D_grads)
# Update gradients
mx.eval(disc.parameters(), disc_opt.state)
# TODO Train Generator
G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
# Update optimizer
gen_opt.update(gen, G_grads)
# Update gradients
mx.eval(gen.parameters(), gen_opt.state)
if epoch%100==0:
print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch,idx,D_loss,G_loss))
fake_noise = mx.array(get_noise(batch_size, z_dim))
fake = gen(fake_noise)
show_images(epoch,fake)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument(
"--dataset",
type=str,
default="mnist",
choices=["mnist", "fashion_mnist"],
help="The dataset to use.",
)
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)
main(args)

83
gan/mnist.py Normal file
View File

@@ -0,0 +1,83 @@
# Copyright © 2023 Apple Inc.
import gzip
import os
import pickle
from urllib import request
import numpy as np
def mnist(
save_dir="/tmp",
base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
filename="mnist.pkl",
):
"""
Load the MNIST dataset in 4 tensors: train images, train labels,
test images, and test labels.
Checks `save_dir` for already downloaded data otherwise downloads.
Download code modified from:
https://github.com/hsjeong5/MNIST-for-Numpy
"""
def download_and_save(save_file):
filename = [
["training_images", "train-images-idx3-ubyte.gz"],
["test_images", "t10k-images-idx3-ubyte.gz"],
["training_labels", "train-labels-idx1-ubyte.gz"],
["test_labels", "t10k-labels-idx1-ubyte.gz"],
]
mnist = {}
for name in filename:
out_file = os.path.join("/tmp", name[1])
request.urlretrieve(base_url + name[1], out_file)
for name in filename[:2]:
out_file = os.path.join("/tmp", name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
-1, 28 * 28
)
for name in filename[-2:]:
out_file = os.path.join("/tmp", name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
with open(save_file, "wb") as f:
pickle.dump(mnist, f)
save_file = os.path.join(save_dir, filename)
if not os.path.exists(save_file):
download_and_save(save_file)
with open(save_file, "rb") as f:
mnist = pickle.load(f)
def preproc(x):
return x.astype(np.float32) / 255.0
mnist["training_images"] = preproc(mnist["training_images"])
mnist["test_images"] = preproc(mnist["test_images"])
return (
mnist["training_images"],
mnist["training_labels"].astype(np.uint32),
mnist["test_images"],
mnist["test_labels"].astype(np.uint32),
)
def fashion_mnist(save_dir="/tmp"):
return mnist(
save_dir,
base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
filename="fashion_mnist.pkl",
)
if __name__ == "__main__":
train_x, train_y, test_x, test_y = mnist()
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
assert train_y.shape == (60000,), "Wrong training set size"
assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
assert test_y.shape == (10000,), "Wrong test set size"

636
gan/playground.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,12 @@ audio_file = "mlx_whisper/assets/ls_test.flac"
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark script.") parser = argparse.ArgumentParser(description="Benchmark script.")
parser.add_argument(
"--mlx-dir",
type=str,
default="mlx_models",
help="The folder of MLX models",
)
parser.add_argument( parser.add_argument(
"--all", "--all",
action="store_true", action="store_true",

View File

@@ -382,7 +382,7 @@ if __name__ == "__main__":
# Save weights # Save weights
print("[INFO] Saving") print("[INFO] Saving")
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights) mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
# Save config.json with model_type # Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f: with open(str(mlx_path / "config.json"), "w") as f:

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.4.3" __version__ = "0.4.1"

View File

@@ -156,42 +156,42 @@ def build_parser():
"--prepend-punctuations", "--prepend-punctuations",
type=str, type=str,
default="\"'“¿([{-", default="\"'“¿([{-",
help="If --word-timestamps is True, merge these punctuation symbols with the next word", help="If word-timestamps is True, merge these punctuation symbols with the next word",
) )
parser.add_argument( parser.add_argument(
"--append-punctuations", "--append-punctuations",
type=str, type=str,
default="\"'.。,!?::”)]}、", default="\"'.。,!?::”)]}、",
help="If --word-timestamps is True, merge these punctuation symbols with the previous word", help="If word_timestamps is True, merge these punctuation symbols with the previous word",
) )
parser.add_argument( parser.add_argument(
"--highlight-words", "--highlight-words",
type=str2bool, type=str2bool,
default=False, default=False,
help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt", help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
) )
parser.add_argument( parser.add_argument(
"--max-line-width", "--max-line-width",
type=int, type=int,
default=None, default=None,
help="(requires --word-timestampss True) the maximum number of characters in a line before breaking the line", help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
) )
parser.add_argument( parser.add_argument(
"--max-line-count", "--max-line-count",
type=int, type=int,
default=None, default=None,
help="(requires --word-timestamps True) the maximum number of lines in a segment", help="(requires --word_timestamps True) the maximum number of lines in a segment",
) )
parser.add_argument( parser.add_argument(
"--max-words-per-line", "--max-words-per-line",
type=int, type=int,
default=None, default=None,
help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment", help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
) )
parser.add_argument( parser.add_argument(
"--hallucination-silence-threshold", "--hallucination-silence-threshold",
type=optional_float, type=optional_float,
help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected", help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
) )
parser.add_argument( parser.add_argument(
"--clip-timestamps", "--clip-timestamps",

View File

@@ -265,7 +265,7 @@ class GreedyDecoder(TokenDecoder):
else: else:
next_tokens = categorical(logits, self.temperature) next_tokens = categorical(logits, self.temperature)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) logprobs = logits - mx.logsumexp(logits, axis=-1)
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
@@ -380,7 +380,7 @@ class ApplyTimestampRules(LogitFilter):
# if sum of probability over timestamps is above any other token, sample timestamp # if sum of probability over timestamps is above any other token, sample timestamp
mask = mx.array(mask) mask = mx.array(mask)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) logprobs = logits - mx.logsumexp(logits, axis=-1)
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp( timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
axis=-1, keepdims=True axis=-1, keepdims=True
) )
@@ -603,7 +603,6 @@ class DecodingTask:
inputs = tokens[:, -1:] inputs = tokens[:, -1:]
if tokens.shape[-1] > self.n_ctx: if tokens.shape[-1] > self.n_ctx:
break break
next_tokens, next_completed, next_sum_logprobs, _ = _step( next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs inputs, audio_features, tokens, sum_logprobs
) )
@@ -644,7 +643,9 @@ class DecodingTask:
tokens = mx.broadcast_to( tokens = mx.broadcast_to(
tokens, [n_audio, self.n_group, len(self.initial_tokens)] tokens, [n_audio, self.n_group, len(self.initial_tokens)]
) )
tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens))) tokens = tokens.reshape(
tokens, (n_audio * self.n_group, len(self.initial_tokens))
)
# call the main sampling loop # call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

View File

@@ -26,9 +26,6 @@ def load_model(
model_args = whisper.ModelDimensions(**config) model_args = whisper.ModelDimensions(**config)
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
wf = model_path / "model.safetensors"
if not wf.exists():
wf = model_path / "weights.safetensors" wf = model_path / "weights.safetensors"
if not wf.exists(): if not wf.exists():
wf = model_path / "weights.npz" wf = model_path / "weights.npz"

View File

@@ -62,7 +62,7 @@ class ModelHolder:
def transcribe( def transcribe(
audio: Union[str, np.ndarray, mx.array], audio: Union[str, np.ndarray, mx.array],
*, *,
path_or_hf_repo: str = "mlx-community/whisper-turbo", path_or_hf_repo: str = "mlx-community/whisper-tiny",
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4, compression_ratio_threshold: Optional[float] = 2.4,

View File

@@ -10,7 +10,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"id": "a9f4b67f", "id": "a9f4b67f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -31,7 +31,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"id": "5a45bf5a", "id": "5a45bf5a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -58,10 +58,30 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "51bd2ed4", "id": "51bd2ed4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fetching 7 files: 100%|███████████████████████| 7/7 [00:00<00:00, 120328.39it/s]\n",
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivot\n",
"==========\n",
"Prompt: 12 tokens, 78.111 tokens-per-sec\n",
"Generation: 100 tokens, 32.263 tokens-per-sec\n",
"Peak memory: 4.138 GB\n"
]
}
],
"source": [ "source": [
"!mlx_lm.generate --model \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\" \\\n", "!mlx_lm.generate --model \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\" \\\n",
" --prompt \"Write a quick sort in Swift\"" " --prompt \"Write a quick sort in Swift\""
@@ -77,10 +97,60 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"id": "f7add212", "id": "f7add212",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fetching 7 files: 100%|███████████████████████| 7/7 [00:00<00:00, 100205.22it/s]\n",
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivotIndex - 1)\n",
" quickSort(&array, pivotIndex + 1, high)\n",
" }\n",
"}\n",
"\n",
"func partition(_ array: inout [Int], _ low: Int, _ high: Int) -> Int {\n",
" let pivot = array[high]\n",
" var i = low\n",
" for j in low..<high {\n",
" if array[j] < pivot {\n",
" swapAt(&array, i, j)\n",
" i += 1\n",
" }\n",
" }\n",
" swapAt(&array, i, high)\n",
" return i\n",
"}\n",
"\n",
"func swapAt(_ array: inout [Int], _ i: Int, _ j: Int) {\n",
" let temp = array[i]\n",
" array[i] = array[j]\n",
" array[j] = temp\n",
"}\n",
"\n",
"// Example usage:\n",
"var arr = [3,6,8,5,2,1,9,7,4]\n",
"quickSort(&arr, 0, arr.count - 1)\n",
"print(arr) // Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
"```\n",
"\n",
"This code sorts an array of integers in ascending order using the QuickSort algorithm. The `quickSort` function takes an array, a starting index, and an ending index, and recursively sorts the subarrays on either side of the pivot element. The `partition` function finds the pivot index, and the `swapAt` function swaps two elements at given indices.\n",
"==========\n",
"Prompt: 12 tokens, 79.511 tokens-per-sec\n",
"Generation: 448 tokens, 31.514 tokens-per-sec\n",
"Peak memory: 4.184 GB\n"
]
}
],
"source": [ "source": [
"!mlx_lm.generate --model \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\" \\\n", "!mlx_lm.generate --model \"mlx-community/Mistral-7B-Instruct-v0.3-4bit\" \\\n",
" --prompt \"Write a quick sort in Swift\" \\\n", " --prompt \"Write a quick sort in Swift\" \\\n",
@@ -100,10 +170,62 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"id": "e042a321", "id": "e042a321",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf86d861ac5e4879a194bfbc3f0e908d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivotIndex - 1)\n",
" quickSort(&array, pivotIndex + 1, high)\n",
" }\n",
"}\n",
"\n",
"func partition(_ array: inout [Int], _ low: Int, _ high: Int) -> Int {\n",
" let pivot = array[high]\n",
" var i = low\n",
" for j in low..<high {\n",
" if array[j] < pivot {\n",
" swapAt(&array, i, j)\n",
" i += 1\n",
" }\n",
" }\n",
" swapAt(&array, i, high)\n",
" return i\n",
"}\n",
"\n",
"func swapAt(_ array: inout [Int], _ i: Int, _ j: Int) {\n",
" let temp\n",
"==========\n",
"Prompt: 12 tokens, 78.600 tokens-per-sec\n",
"Generation: 256 tokens, 31.893 tokens-per-sec\n",
"Peak memory: 4.184 GB\n"
]
}
],
"source": [ "source": [
"# Using MLX LM from Python\n", "# Using MLX LM from Python\n",
"\n", "\n",
@@ -133,10 +255,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"id": "629dfa50", "id": "629dfa50",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6602e0adecbe4b58ba99e514fc9c9032",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"from mlx_lm import load, generate\n", "from mlx_lm import load, generate\n",
"\n", "\n",
@@ -165,10 +302,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 9,
"id": "a3b56bdc", "id": "a3b56bdc",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Attention(\n",
" (q_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)\n",
" (k_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)\n",
" (v_proj): QuantizedLinear(input_dims=4096, output_dims=1024, bias=False, group_size=64, bits=4)\n",
" (o_proj): QuantizedLinear(input_dims=4096, output_dims=4096, bias=False, group_size=64, bits=4)\n",
" (rope): RoPE(128, traditional=False)\n",
")\n"
]
}
],
"source": [ "source": [
"print(model.layers[0].self_attn)" "print(model.layers[0].self_attn)"
] ]
@@ -183,10 +334,62 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 10,
"id": "775fd3f3", "id": "775fd3f3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed3f0f097da64b379819f577a29dc9f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Here's a simple implementation of the QuickSort algorithm in Swift. This version uses Swift's built-in `swapAt()` function to swap elements in an array.\n",
"\n",
"```swift\n",
"func quickSort(_ array: inout [Int], _ low: Int, _ high: Int) {\n",
" if low < high {\n",
" let pivotIndex = partition(array, low, high)\n",
" quickSort(&array, low, pivotIndex - 1)\n",
" quickSort(&array, pivotIndex + 1, high)\n",
" }\n",
"}\n",
"\n",
"func partition(_ array: inout [Int], _ low: Int, _ high: Int) -> Int {\n",
" let pivot = array[high]\n",
" var i = low\n",
" for j in low..<high {\n",
" if array[j] < pivot {\n",
" swapAt(&array, i, j)\n",
" i += 1\n",
" }\n",
" }\n",
" swapAt(&array, i, high)\n",
" return i\n",
"}\n",
"\n",
"func swapAt(_ array: inout [Int], _ i: Int, _ j: Int) {\n",
" let temp\n",
"==========\n",
"Prompt: 12 tokens, 76.085 tokens-per-sec\n",
"Generation: 256 tokens, 31.792 tokens-per-sec\n",
"Peak memory: 8.155 GB\n"
]
}
],
"source": [ "source": [
"from mlx_lm import load, generate\n", "from mlx_lm import load, generate\n",
"from mlx_lm.models.cache import make_prompt_cache\n", "from mlx_lm.models.cache import make_prompt_cache\n",
@@ -217,10 +420,30 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 11,
"id": "0d669073", "id": "0d669073",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Imagine you have a big box full of toys. You want to sort them so that all the red toys are together, all the blue toys are together, and all the green toys are together.\n",
"\n",
"1. First, you pick one toy (this is your pivot toy).\n",
"2. Then, you look at all the other toys one by one. If a toy is not red, you move it to the left if it's not red, and if it's blue, you move it to the right. You keep doing this until you have looked at all the toys.\n",
"3. Now, you have a group of toys on the left that are red or blue, and a group of toys on the right that are green or blue. You swap the pivot toy with one of the toys in the group on the left or right, depending on whether you want red toys on the left or right.\n",
"4. Now, you repeat the same process with the group of toys on the left and the group of toys on the right, until all the toys are sorted!\n",
"\n",
"This is a quick way to sort a big box of toys, and it's called QuickSort!\n",
"==========\n",
"Prompt: 16 tokens, 116.542 tokens-per-sec\n",
"Generation: 245 tokens, 29.632 tokens-per-sec\n",
"Peak memory: 8.155 GB\n"
]
}
],
"source": [ "source": [
"prompt = \"how can I explain it to a five year old?\"\n", "prompt = \"how can I explain it to a five year old?\"\n",
"messages = [{\"role\": \"user\", \"content\": prompt}]\n", "messages = [{\"role\": \"user\", \"content\": prompt}]\n",
@@ -245,10 +468,22 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 12,
"id": "f8218994", "id": "f8218994",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Loading\n",
"Fetching 9 files: 100%|███████████████████████| 9/9 [00:00<00:00, 161319.38it/s]\n",
"[INFO] Using dtype: float16\n",
"[INFO] Quantizing\n",
"[INFO] Quantized model with 4.500 bits per weight.\n"
]
}
],
"source": [ "source": [
"import os\n", "import os\n",
"mlx_path=\"./mistral-7b-v0.3-4bit\"\n", "mlx_path=\"./mistral-7b-v0.3-4bit\"\n",
@@ -261,10 +496,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 13,
"id": "d4e62b96", "id": "d4e62b96",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size of original bfloat16 model\n",
"===============================\n",
"3.8174 GB\n",
"\n",
"Size of quantized model\n",
"===============================\n",
"13.5049 GB\n"
]
}
],
"source": [ "source": [
"import subprocess\n", "import subprocess\n",
"\n", "\n",
@@ -274,13 +523,12 @@
" size_gb = size_mb / 1024\n", " size_gb = size_mb / 1024\n",
" return size_gb\n", " return size_gb\n",
"\n", "\n",
"\n", "directory_path = './mistral-7b-v0.3-4bit'\n",
"directory_path = os.path.expanduser('~/.cache/huggingface/hub/models--mlx-community--Mistral-7B-Instruct-v0.3')\n",
"print(\"Size of original bfloat16 model\")\n", "print(\"Size of original bfloat16 model\")\n",
"print(\"===============================\")\n", "print(\"===============================\")\n",
"print(f\"{get_directory_size_mb(directory_path):2.4f} GB\")\n", "print(f\"{get_directory_size_mb(directory_path):2.4f} GB\")\n",
"print()\n", "print()\n",
"directory_path = './mistral-7b-v0.3-4bit'\n", "directory_path = os.path.expanduser('~/.cache/huggingface/hub/models--mlx-community--Mistral-7B-Instruct-v0.3')\n",
"print(\"Size of quantized model\")\n", "print(\"Size of quantized model\")\n",
"print(\"===============================\")\n", "print(\"===============================\")\n",
"print(f\"{get_directory_size_mb(directory_path):2.4f} GB\")" "print(f\"{get_directory_size_mb(directory_path):2.4f} GB\")"
@@ -296,10 +544,79 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 14,
"id": "9d2cd325", "id": "9d2cd325",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Loading\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "15c0aac6b06b4541ab1d5d20f5c5a255",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 12 files: 0%| | 0/12 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba81f7892189428cadcfed19173a0731",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.safetensors: 42%|####1 | 10.5G/25.0G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Using dtype: bfloat16\n",
"[INFO] Quantizing\n",
"[INFO] Quantized model with 4.574 bits per weight.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "60442f8dcab848ef9771a8e2b5516a13",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0%| | 0.00/7.82k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Size of mixed 4-6-bit quantized model\n",
"============================\n",
"3.8799 GB\n"
]
}
],
"source": [ "source": [
"# Model quantization with MLX LM in Python\n", "# Model quantization with MLX LM in Python\n",
"\n", "\n",
@@ -341,10 +658,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 15,
"id": "5efb794d", "id": "5efb794d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"The latest Super Bowl, Super Bowl LV (55), was played on February 7, 2021, between the Kansas City Chiefs and the Tampa Bay Buccaneers. The Tampa Bay Buccaneers, led by quarterback Tom Brady, won the game, making it his seventh Super Bowl victory. This made Tom Brady the most successful quarterback in Super Bowl history.\n",
"==========\n",
"Prompt: 11 tokens, 8.131 tokens-per-sec\n",
"Generation: 87 tokens, 31.385 tokens-per-sec\n",
"Peak memory: 4.137 GB\n"
]
}
],
"source": [ "source": [
"!mlx_lm.generate --model \"./mistral-7b-v0.3-4bit\" \\\n", "!mlx_lm.generate --model \"./mistral-7b-v0.3-4bit\" \\\n",
" --prompt \"Who played in the latest super bowl?\"" " --prompt \"Who played in the latest super bowl?\""
@@ -360,7 +690,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 16,
"id": "b4c31126", "id": "b4c31126",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -383,10 +713,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 20,
"id": "7dcf9874", "id": "7dcf9874",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"In the latest Super Bowl, the Philadelphia Eagles soared to victory, claiming their championship title with a resounding 40-22 win over the Kansas City Chiefs. The Eagles' triumphant flight was led by their fearless leader, Jalen Hurts, who not only secured his place in the annals of Super Bowl history but also etched his name into the hearts of Eagles fans everywhere. This wasn't just any Super Bowl; it was Super Bowl\n",
"==========\n",
"Prompt: 11 tokens, 28.533 tokens-per-sec\n",
"Generation: 100 tokens, 30.986 tokens-per-sec\n",
"Peak memory: 4.151 GB\n"
]
}
],
"source": [ "source": [
"!mlx_lm.generate --model \"./mistral-7b-v0.3-4bit\" \\\n", "!mlx_lm.generate --model \"./mistral-7b-v0.3-4bit\" \\\n",
" --prompt \"Who played in the latest super bowl?\" \\\n", " --prompt \"Who played in the latest super bowl?\" \\\n",
@@ -403,10 +746,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 21,
"id": "8935f7b6", "id": "8935f7b6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n"
]
}
],
"source": [ "source": [
"!mlx_lm.fuse --model \"./mistral-7b-v0.3-4bit\" \\\n", "!mlx_lm.fuse --model \"./mistral-7b-v0.3-4bit\" \\\n",
" --adapter-path \"adapters\" \\\n", " --adapter-path \"adapters\" \\\n",
@@ -423,10 +774,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 22,
"id": "343a8977", "id": "343a8977",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"The latest Super Bowl, Super Bowl LIX, was played between the Philadelphia Eagles and the Kansas City Chiefs. The Philadelphia Eagles emerged victorious, with Jalen Hurts leading the charge for the Eagles.\n",
"==========\n",
"Prompt: 11 tokens, 11.760 tokens-per-sec\n",
"Generation: 46 tokens, 32.194 tokens-per-sec\n",
"Peak memory: 4.137 GB\n"
]
}
],
"source": [ "source": [
"!mlx_lm.generate --model \"./fused-mistral-7b-v0.3-4bit\" \\\n", "!mlx_lm.generate --model \"./fused-mistral-7b-v0.3-4bit\" \\\n",
" --prompt \"Who played in the latest super bowl?\" \\\n", " --prompt \"Who played in the latest super bowl?\" \\\n",
@@ -436,7 +800,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "mlx",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -450,7 +814,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.17" "version": "3.12.9"
} }
}, },
"nbformat": 4, "nbformat": 4,