mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
4 Commits
3f18394321
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e52c128d11 | ||
|
|
7ddca42f4d | ||
|
|
21a4d4cdab | ||
|
|
8e4391ca21 |
@@ -1,40 +0,0 @@
|
||||
version: 2.1
|
||||
|
||||
orbs:
|
||||
apple: ml-explore/pr-approval@0.1.0
|
||||
|
||||
jobs:
|
||||
linux_build_and_test:
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Run style checks
|
||||
command: |
|
||||
pip install pre-commit
|
||||
pre-commit run --all
|
||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- linux_build_and_test
|
||||
|
||||
prb:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^pull/\\d+(/head)?$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- hold:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
25
.github/workflows/pull_request.yml
vendored
Normal file
25
.github/workflows/pull_request.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# Serialize runs per workflow and ref; a newer run supersedes an older one.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  # Branch refs are spelled "refs/heads/<branch>". Comparing against
  # "refs/head/main" never matches, which would cancel in-progress runs on
  # main as well. Only non-main refs should have stale runs cancelled.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
check_lint:
|
||||
if: github.repository == 'ml-explore/mlx-examples'
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 208 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 142 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 112 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 101 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 79 KiB |
195
gan/main.py
195
gan/main.py
@@ -1,195 +0,0 @@
|
||||
import mnist
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Generator Block
|
||||
def GenBlock(in_dim: int, out_dim: int):
    """Build one generator stage: Linear -> BatchNorm -> LeakyReLU(0.2).

    Args:
        in_dim: Number of input features of the linear layer.
        out_dim: Number of output features (also the BatchNorm width).

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    stage = [
        nn.Linear(in_dim, out_dim),
        # NOTE(review): 0.8 is passed positionally. Per the MLX docs the
        # second positional parameter of nn.BatchNorm is `eps`, not
        # `momentum` -- confirm which one was intended.
        nn.BatchNorm(out_dim, 0.8),
        nn.LeakyReLU(0.2),
    ]
    return nn.Sequential(*stage)
|
||||
|
||||
# Generator Model
|
||||
class Generator(nn.Module):
    """MLP generator mapping a noise vector to a flattened image.

    Three ``GenBlock`` stages widen the features from ``z_dim`` to
    ``hidden_dim``, ``2 * hidden_dim`` and ``4 * hidden_dim``; a final linear
    layer projects to ``im_dim`` and the output is squashed with ``tanh``.
    """

    def __init__(self, z_dim: int = 32, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()

        # Feature width at the input of each stage, plus the last stage's output.
        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]
        stages = [GenBlock(src, dst) for src, dst in zip(widths, widths[1:])]

        self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], im_dim))

    def __call__(self, noise):
        """Return ``tanh`` of the network output for ``noise``."""
        return mx.tanh(self.gen(noise))
|
||||
|
||||
# make 2D noise with shape n_samples x z_dim
|
||||
def get_noise(n_samples: int, z_dim: int) -> mx.array:
    """Sample standard-normal noise for the generator.

    Args:
        n_samples: Number of noise vectors (rows) to draw.
        z_dim: Dimensionality of each noise vector (columns).

    Returns:
        An MLX array of shape ``(n_samples, z_dim)``.
    """
    return mx.random.normal(shape=(n_samples, z_dim))
|
||||
|
||||
#---------------------------------------------#
|
||||
|
||||
# Discriminator Block
|
||||
def DisBlock(in_dim: int, out_dim: int):
    """Build one discriminator stage: Linear -> LeakyReLU(0.2) -> Dropout(0.3).

    Args:
        in_dim: Number of input features of the linear layer.
        out_dim: Number of output features of the linear layer.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    projection = nn.Linear(in_dim, out_dim)
    activation = nn.LeakyReLU(negative_slope=0.2)
    dropout = nn.Dropout(0.3)
    return nn.Sequential(projection, activation, dropout)
|
||||
|
||||
# Discriminator Model
|
||||
class Discriminator(nn.Module):
    """MLP discriminator scoring flattened images.

    Three ``DisBlock`` stages narrow the features from ``im_dim`` to
    ``4 * hidden_dim``, ``2 * hidden_dim`` and ``hidden_dim``; a final linear
    layer followed by a sigmoid produces one value per input row.
    """

    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()

        # Feature width at the input of each stage, plus the last stage's output.
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        stages = [DisBlock(src, dst) for src, dst in zip(widths, widths[1:])]

        self.disc = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], 1),
            nn.Sigmoid(),
        )

    def __call__(self, noise):
        """Score the input batch.

        The parameter is named ``noise`` for historical reasons; the callers
        in this file pass image batches.
        """
        scores = self.disc(noise)
        return scores
|
||||
|
||||
# Discriminator Loss
|
||||
def disc_loss(gen, disc, real, num_images, z_dim):
    """Discriminator loss: mean of the BCE on fake and on real batches.

    Args:
        gen: Generator; called on freshly sampled noise.
        disc: Discriminator; called on both the fake and the real batch.
        real: Batch of real images fed to ``disc``.
        num_images: Number of fake images to generate.
        z_dim: Dimensionality of the generator noise.

    Returns:
        Scalar loss ``(fake_loss + real_loss) / 2``.
    """
    fake_images = gen(get_noise(num_images, z_dim))

    # The Discriminator in this file ends with nn.Sigmoid, so its outputs are
    # already probabilities. Using with_logits=True would apply a second
    # sigmoid inside the loss, hence with_logits=False.
    fake_pred = disc(fake_images)
    fake_labels = mx.zeros((fake_images.shape[0], 1))
    fake_loss = mx.mean(
        nn.losses.binary_cross_entropy(fake_pred, fake_labels, with_logits=False)
    )

    real_pred = disc(real)
    real_labels = mx.ones((real.shape[0], 1))
    real_loss = mx.mean(
        nn.losses.binary_cross_entropy(real_pred, real_labels, with_logits=False)
    )

    return (fake_loss + real_loss) / 2.0
|
||||
|
||||
# Generator Loss
|
||||
def gen_loss(gen, disc, num_images, z_dim):
    """Generator loss: BCE of the discriminator's fake scores against ones.

    Args:
        gen: Generator; called on freshly sampled noise.
        disc: Discriminator used to score the generated images.
        num_images: Number of fake images to generate.
        z_dim: Dimensionality of the generator noise.

    Returns:
        Scalar mean loss.
    """
    fake_images = gen(get_noise(num_images, z_dim))
    fake_pred = disc(fake_images)

    # The generator is rewarded when its fakes are scored as real.
    target_labels = mx.ones((fake_images.shape[0], 1))

    # The Discriminator in this file ends with nn.Sigmoid, so its outputs are
    # already probabilities; with_logits=True would apply sigmoid twice.
    loss = nn.losses.binary_cross_entropy(
        fake_pred, target_labels, with_logits=False
    )

    return mx.mean(loss)
|
||||
|
||||
# make batch of images
|
||||
def batch_iterate(batch_size: int, ipt: np.ndarray):
    """Yield shuffled mini-batches of ``ipt`` along its first axis.

    A fresh random permutation of the row indices is drawn on every call, so
    each pass visits every row exactly once in a new order. The final batch
    is shorter when ``len(ipt)`` is not a multiple of ``batch_size``.

    Args:
        batch_size: Maximum number of rows per batch.
        ipt: Array supporting ``len`` and indexing with an integer index
            array (e.g. a NumPy array).

    Yields:
        Slices of ``ipt`` with at most ``batch_size`` rows each.
    """
    n_rows = len(ipt)
    order = np.random.permutation(n_rows)
    for start in range(0, n_rows, batch_size):
        yield ipt[order[start : start + batch_size]]
|
||||
|
||||
# plot batch of images at epoch steps
|
||||
def show_images(epoch_num: int, imgs, num_imgs: int = 25):
    """Plot up to ``num_imgs`` generated images and save them as a PNG.

    The figure is written to ``gen_images/img_<epoch_num>.png`` (the
    directory is created if missing) and then displayed.

    Args:
        epoch_num: Epoch index used in the output file name.
        imgs: Batch of flattened images; each row is reshaped to 28x28.
        num_imgs: Maximum number of images to draw. With the default of 25
            the grid is 5x5.
    """
    import os

    # Never index past the batch, and honour the caller's requested count.
    count = min(num_imgs, imgs.shape[0])
    if count <= 0:
        return

    # Smallest square grid that fits `count` images.
    side = int(np.ceil(np.sqrt(count)))
    fig, axes = plt.subplots(side, side, figsize=(5, 5), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i < count:
            # Convert to NumPy explicitly so matplotlib gets a plain ndarray.
            ax.imshow(np.array(imgs[i]).reshape(28, 28), cmap='gray')

    plt.tight_layout()
    os.makedirs('gen_images', exist_ok=True)
    plt.savefig('gen_images/img_{}.png'.format(epoch_num))
    plt.show()
    # Release the figure; this function is called repeatedly during training.
    plt.close(fig)
|
||||
|
||||
def main(args: argparse.Namespace):
    """Train the GAN and periodically save sample images.

    Args:
        args: Parsed command-line arguments. ``args.dataset`` selects the
            loader function in the ``mnist`` module ("mnist" when absent).
    """
    seed = 42
    n_epochs = 500
    z_dim = 128
    batch_size = 128
    lr = 2e-5

    mx.random.seed(seed)
    # batch_iterate shuffles with NumPy, so seed it too for reproducible runs.
    np.random.seed(seed)

    # Load the data with the loader chosen on the command line.
    dataset_name = getattr(args, "dataset", "mnist")
    load_dataset = getattr(mnist, dataset_name)
    train_images, *_ = map(np.array, load_dataset())

    # Rescale images from [0, 1] to [-1, 1] to match the generator's tanh output.
    train_images = train_images * 2.0 - 1.0

    gen = Generator(z_dim)
    mx.eval(gen.parameters())
    gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    disc = Discriminator()
    mx.eval(disc.parameters())
    disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    # Each wrapper differentiates its loss w.r.t. the named model's parameters.
    D_loss_grad = nn.value_and_grad(disc, disc_loss)
    G_loss_grad = nn.value_and_grad(gen, gen_loss)

    for epoch in tqdm(range(n_epochs)):

        for idx, real in enumerate(batch_iterate(batch_size, train_images)):

            # Discriminator step.
            D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)
            disc_opt.update(disc, D_grads)
            # Force evaluation of the updated parameters and optimizer state.
            mx.eval(disc.parameters(), disc_opt.state)

            # Generator step.
            G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
            gen_opt.update(gen, G_grads)
            # Force evaluation of the updated parameters and optimizer state.
            mx.eval(gen.parameters(), gen_opt.state)

        # Report and sample once per 100 epochs, using the last batch's losses.
        if epoch % 100 == 0:
            print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch, idx, D_loss, G_loss))
            fake_noise = get_noise(batch_size, z_dim)
            fake = gen(fake_noise)
            show_images(epoch, fake)
|
||||
|
||||
if __name__ == "__main__":
    # Command-line interface: back-end selection and dataset choice.
    parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    dataset_choices = ["mnist", "fashion_mnist"]
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=dataset_choices,
        help="The dataset to use.",
    )
    args = parser.parse_args()

    # Run on the CPU unless --gpu was given.
    run_on_cpu = not args.gpu
    if run_on_cpu:
        mx.set_default_device(mx.cpu)

    main(args)
|
||||
83
gan/mnist.py
83
gan/mnist.py
@@ -1,83 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import gzip
|
||||
import os
|
||||
import pickle
|
||||
from urllib import request
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def mnist(
    save_dir="/tmp",
    base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
    filename="mnist.pkl",
):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for already downloaded data otherwise downloads.
    Both the downloaded archives and the pickled cache are stored in
    `save_dir`.

    Download code modified from:
      https://github.com/hsjeong5/MNIST-for-Numpy

    Args:
        save_dir: Directory holding the cache file and downloaded archives.
        base_url: URL prefix the four ``.gz`` archives are fetched from.
        filename: Name of the pickled cache file inside ``save_dir``.

    Returns:
        Tuple ``(train_images, train_labels, test_images, test_labels)``.
        Images are float32 scaled by 1/255; labels are uint32.
    """

    def download_and_save(save_file):
        # (key in the cache dict, archive name on the server)
        archives = [
            ["training_images", "train-images-idx3-ubyte.gz"],
            ["test_images", "t10k-images-idx3-ubyte.gz"],
            ["training_labels", "train-labels-idx1-ubyte.gz"],
            ["test_labels", "t10k-labels-idx1-ubyte.gz"],
        ]

        os.makedirs(save_dir, exist_ok=True)

        raw = {}
        for _, archive in archives:
            request.urlretrieve(base_url + archive, os.path.join(save_dir, archive))
        # Image archives: skip the 16-byte header, one 28*28 row per image.
        for key, archive in archives[:2]:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                raw[key] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        # Label archives: skip the 8-byte header.
        for key, archive in archives[-2:]:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                raw[key] = np.frombuffer(f.read(), np.uint8, offset=8)
        with open(save_file, "wb") as f:
            pickle.dump(raw, f)

    save_file = os.path.join(save_dir, filename)
    if not os.path.exists(save_file):
        download_and_save(save_file)
    # NOTE: pickle.load executes arbitrary code from the file; only point
    # `save_dir`/`filename` at a cache this function wrote itself.
    with open(save_file, "rb") as f:
        data = pickle.load(f)

    def preproc(x):
        return x.astype(np.float32) / 255.0

    return (
        preproc(data["training_images"]),
        data["training_labels"].astype(np.uint32),
        preproc(data["test_images"]),
        data["test_labels"].astype(np.uint32),
    )
|
||||
|
||||
|
||||
def fashion_mnist(save_dir="/tmp"):
    """Load Fashion-MNIST via ``mnist`` with its own URL and cache file.

    Args:
        save_dir: Directory passed through to ``mnist`` for caching.

    Returns:
        Whatever ``mnist`` returns for the Fashion-MNIST archives.
    """
    fashion_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
    cache_name = "fashion_mnist.pkl"
    return mnist(save_dir, base_url=fashion_url, filename=cache_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: load MNIST and check the shape of every returned split.
    train_x, train_y, test_x, test_y = mnist()
    shape_checks = (
        (train_x, (60000, 28 * 28), "Wrong training set size"),
        (train_y, (60000,), "Wrong training set size"),
        (test_x, (10000, 28 * 28), "Wrong test set size"),
        (test_y, (10000,), "Wrong test set size"),
    )
    for split, expected_shape, message in shape_checks:
        assert split.shape == expected_shape, message
|
||||
File diff suppressed because one or more lines are too long
@@ -11,12 +11,6 @@ audio_file = "mlx_whisper/assets/ls_test.flac"
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="Benchmark script.")
|
||||
parser.add_argument(
|
||||
"--mlx-dir",
|
||||
type=str,
|
||||
default="mlx_models",
|
||||
help="The folder of MLX models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--all",
|
||||
action="store_true",
|
||||
|
||||
@@ -382,7 +382,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Save weights
|
||||
print("[INFO] Saving")
|
||||
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
||||
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(str(mlx_path / "config.json"), "w") as f:
|
||||
|
||||
@@ -156,42 +156,42 @@ def build_parser():
|
||||
"--prepend-punctuations",
|
||||
type=str,
|
||||
default="\"'“¿([{-",
|
||||
help="If word-timestamps is True, merge these punctuation symbols with the next word",
|
||||
help="If --word-timestamps is True, merge these punctuation symbols with the next word",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--append-punctuations",
|
||||
type=str,
|
||||
default="\"'.。,,!!??::”)]}、",
|
||||
help="If word_timestamps is True, merge these punctuation symbols with the previous word",
|
||||
help="If --word-timestamps is True, merge these punctuation symbols with the previous word",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highlight-words",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
|
||||
help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-line-width",
|
||||
type=int,
|
||||
default=None,
|
||||
help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
|
||||
help="(requires --word-timestamps True) the maximum number of characters in a line before breaking the line",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-line-count",
|
||||
type=int,
|
||||
default=None,
|
||||
help="(requires --word_timestamps True) the maximum number of lines in a segment",
|
||||
help="(requires --word-timestamps True) the maximum number of lines in a segment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-words-per-line",
|
||||
type=int,
|
||||
default=None,
|
||||
help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
|
||||
help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hallucination-silence-threshold",
|
||||
type=optional_float,
|
||||
help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
|
||||
help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip-timestamps",
|
||||
|
||||
@@ -26,6 +26,9 @@ def load_model(
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
|
||||
wf = model_path / "model.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
|
||||
@@ -62,7 +62,7 @@ class ModelHolder:
|
||||
def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array],
|
||||
*,
|
||||
path_or_hf_repo: str = "mlx-community/whisper-tiny",
|
||||
path_or_hf_repo: str = "mlx-community/whisper-turbo",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
|
||||
Reference in New Issue
Block a user