Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Update a few examples to use compile (#420)

* update a few examples to use compile
* update mnist
* add compile to vae and rename some stuff for simplicity
* update reqs
* use state in eval
* GCN example with RNG + dropout
* add a bit of prefetching

This commit is contained in:
parent da7adae5ec
commit f45a1ab83c
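Every file touched below follows the same pattern: collect the mutable state the training step uses (the model parameters via `model.state`, the optimizer state, and `mx.random.state` when the step is stochastic), compile the forward/backward/update function with that state declared as `inputs` and `outputs`, and call `mx.eval(state)` once per iteration instead of evaluating parameters and optimizer state separately. A minimal, self-contained sketch of the pattern follows; the tiny model, data, and hyperparameters are illustrative placeholders, not taken from the diff.

# Sketch of the compiled training-step pattern used throughout this commit.
# The model, data, and hyperparameters below are placeholders.
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
mx.eval(model.parameters())
optimizer = optim.SGD(learning_rate=0.1)


def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")


loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Arrays captured by the step (parameters, optimizer state) are declared as
# inputs and outputs so mx.compile treats their updates as part of the graph.
state = [model.state, optimizer.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss


x = mx.random.normal((8, 16))
y = mx.random.randint(0, 2, (8,))
loss = step(x, y)
mx.eval(state)  # force evaluation of the lazily updated state once per step
print(loss.item())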
@@ -1,14 +1,12 @@
-import math
-
-import mlx.core as mx
+import numpy as np
 from mlx.data.datasets import load_cifar10


 def get_cifar10(batch_size, root=None):
     tr = load_cifar10(root=root)

-    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
-    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
+    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
+    std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

     def normalize(x):
         x = x.astype("float32") / 255.0
@@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None):
         .image_random_crop("image", 32, 32)
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     test = load_cifar10(root=root, train=False)
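The `.prefetch(4, 4)` calls added to the data pipelines (the "add a bit of prefetching" item in the commit message) let mlx-data prepare upcoming batches on background threads while the compiled step runs. A hypothetical pipeline showing where the call sits; the dataset, the transform, and my reading of the two arguments as prefetch depth and worker threads are assumptions for illustration, not taken from the diff.

# Hypothetical mlx-data pipeline; dataset and transform are placeholders.
from mlx.data.datasets import load_mnist


def to_float(x):
    return x.astype("float32") / 255.0


def make_iter(batch_size, root=None):
    data = load_mnist(root=root)
    return (
        data.shuffle()
        .to_stream()
        .key_transform("image", to_float)
        .batch(batch_size)
        # keep a few prepared batches ahead using background threads
        # (argument meaning is my reading of the mlx-data API)
        .prefetch(4, 4)
    )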
@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -33,19 +34,25 @@ def train_epoch(model, train_iter, optimizer, epoch):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)
-
     losses = []
     accs = []
     samples_per_sec = []

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inp, tgt):
+        train_step_fn = nn.value_and_grad(model, train_step)
+        (loss, acc), grads = train_step_fn(model, inp, tgt)
+        optimizer.update(model, grads)
+        return loss, acc
+
     for batch_counter, batch in enumerate(train_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()
@@ -1,3 +1,3 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
@@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None):
         .image_resize("image", h=img_size[0], w=img_size[1])
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     # iterator over test set
cvae/main.py (104 lines changed)
@@ -2,14 +2,15 @@

 import argparse
 import time
+from functools import partial
 from pathlib import Path

 import dataset
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import model
 import numpy as np
+import vae
 from mlx.utils import tree_flatten
 from PIL import Image

@@ -53,44 +54,6 @@ def loss_fn(model, X):
     return recon_loss + kl_div


-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
 def reconstruct(model, batch, out_file):
     # Reconstruct a single batch only
     images = mx.array(batch["image"])
@@ -127,10 +90,10 @@ def main(args):
     save_dir.mkdir(parents=True, exist_ok=True)

     # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())

-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

     optimizer = optim.AdamW(learning_rate=args.lr)
@@ -139,19 +102,53 @@ def main(args):
     train_batch = next(train_iter)
     test_batch = next(test_iter)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
     for e in range(1, args.epochs + 1):
         # Reset iterators and stats at the beginning of each epoch
         train_iter.reset()
-        vae.train()
+        model.train()

         # Train one epoch
         tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
-        toc = time.perf_counter()
-
-        vae.eval()
+        loss_acc = 0.0
+        throughput_acc = 0.0
+
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
+
+        toc = time.perf_counter()

         print(
             " | ".join(
@@ -163,14 +160,17 @@ def main(args):
                 ]
             )
         )

+        model.eval()
+
         # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")

         # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")

-    vae.save_weights(str(save_dir / "weights.npz"))
+    model.save_weights(str(save_dir / "weights.npz"))


 if __name__ == "__main__":
@@ -1,4 +1,4 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
 Pillow
gcn/main.py (25 lines changed)
@@ -1,4 +1,6 @@
+import time
 from argparse import ArgumentParser
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -47,23 +49,31 @@ def main(args):
     mx.eval(gcn.parameters())

     optimizer = optim.Adam(learning_rate=args.lr)
-    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+
+    state = [gcn.state, optimizer.state, mx.random.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step():
+        loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+        (loss, y_hat), grads = loss_and_grad_fn(
+            gcn, x, adj, y, train_mask, args.weight_decay
+        )
+        optimizer.update(gcn, grads)
+        return loss, y_hat

     best_val_loss = float("inf")
     cnt = 0

     # Training loop
     for epoch in range(args.epochs):
-        # Loss
-        (loss, y_hat), grads = loss_and_grad_fn(
-            gcn, x, adj, y, train_mask, args.weight_decay
-        )
-        optimizer.update(gcn, grads)
-        mx.eval(gcn.parameters(), optimizer.state)
+        tic = time.time()
+        loss, y_hat = step()
+        mx.eval(state)

         # Validation
         val_loss = loss_fn(y_hat[val_mask], y[val_mask])
         val_acc = eval_fn(y_hat[val_mask], y[val_mask])
+        toc = time.time()

         # Early stopping
         if val_loss < best_val_loss:
@@ -81,6 +91,7 @@ def main(args):
                     f"Train loss: {loss.item():.3f}",
                     f"Val loss: {val_loss.item():.3f}",
                     f"Val acc: {val_acc.item():.2f}",
+                    f"Time: {1e3*(toc - tic):.3f} (ms)",
                 ]
             )
         )
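The GCN step is the one place where the compiled state also includes `mx.random.state`: dropout draws random numbers inside the compiled function, so the global RNG state must be declared as an input and output for the randomness to advance across calls. A small sketch of that idea; the dropout layer and input are illustrative, not from the diff.

from functools import partial

import mlx.core as mx
import mlx.nn as nn

dropout = nn.Dropout(p=0.5)
dropout.train()  # dropout only applies in training mode

# The dropout call reads and updates the global PRNG state, so it has to be
# listed among the compiled function's inputs and outputs.
state = [mx.random.state]


@partial(mx.compile, inputs=state, outputs=state)
def noisy_sum(x):
    return dropout(x).sum()


x = mx.ones((4,))
print(noisy_sum(x).item(), noisy_sum(x).item())  # two different draws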
@@ -1,4 +1,4 @@
-mlx
+mlx>=0.1
 numpy
 transformers>=4.37.0
 protobuf
@@ -2,6 +2,7 @@

 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -34,10 +35,6 @@ def loss_fn(model, X, y):
     return nn.losses.cross_entropy(model(X), y, reduction="mean")


-def eval_fn(model, X, y):
-    return mx.mean(mx.argmax(model(X), axis=1) == y)
-
-
 def batch_iterate(batch_size, X, y):
     perm = mx.array(np.random.permutation(y.size))
     for s in range(0, y.size, batch_size):
@@ -65,16 +62,25 @@ def main(args):
     model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
     mx.eval(model.parameters())

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.SGD(learning_rate=learning_rate)
+    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+    @partial(mx.compile, inputs=model.state, outputs=model.state)
+    def step(X, y):
+        loss, grads = loss_and_grad_fn(model, X, y)
+        optimizer.update(model, grads)
+        return loss
+
+    @partial(mx.compile, inputs=model.state)
+    def eval_fn(X, y):
+        return mx.mean(mx.argmax(model(X), axis=1) == y)

     for e in range(num_epochs):
         tic = time.perf_counter()
         for X, y in batch_iterate(batch_size, train_images, train_labels):
-            loss, grads = loss_and_grad_fn(model, X, y)
-            optimizer.update(model, grads)
-            mx.eval(model.parameters(), optimizer.state)
-        accuracy = eval_fn(model, test_images, test_labels)
+            step(X, y)
+            mx.eval(model.state)
+        accuracy = eval_fn(test_images, test_labels)
         toc = time.perf_counter()
         print(
             f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
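Note that in the MNIST example `eval_fn` is compiled with only `inputs=model.state` and no `outputs`: it reads the parameters but never mutates them, so there is no captured state the compiled call needs to write back. The training `step`, by contrast, declares the model state as both inputs and outputs because the parameter update changes it.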
@@ -1,2 +1,2 @@
-mlx
+mlx>=0.2
 numpy
@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import matplotlib.pyplot as plt
 import mlx.core as mx
 import mlx.nn as nn
@@ -27,18 +29,23 @@ def main(args):
     def loss_fn(model, x):
         return -mx.mean(model(x))

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.Adam(learning_rate=args.learning_rate)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, x)
+        optimizer.update(model, grads)
+        return loss
+
     with trange(args.n_steps) as steps:
-        for step in steps:
+        for it in steps:
             idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
-            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
-            optimizer.update(model, grads)
-            mx.eval(model.parameters())
-
-            steps.set_postfix(val=loss)
+            loss = step(mx.array(x[idx]))
+            mx.eval(state)
+            steps.set_postfix(val=loss.item())

     # Plot samples from trained flow
@@ -1,5 +1,5 @@
-mlx
+mlx>=0.2
 numpy
 tqdm
 scikit-learn
 matplotlib
@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import kwt
 import mlx.core as mx
@@ -46,22 +47,30 @@ def prepare_dataset(batch_size, split, root=None):
         .key_transform("audio", normalize)
         .shuffle()
         .batch(batch_size)
+        .to_stream()
+        .prefetch(4, 4)
     )
     return data_iter


-def eval_fn(model, inp, tgt):
-    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
+def eval_fn(model, x, y):
+    return mx.mean(mx.argmax(model(x), axis=1) == y)


 def train_epoch(model, train_iter, optimizer, epoch):
-    def train_step(model, inp, tgt):
-        output = model(inp)
-        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
-        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
+    def train_step(model, x, y):
+        output = model(x)
+        loss = mx.mean(nn.losses.cross_entropy(output, y))
+        acc = mx.mean(mx.argmax(output, axis=1) == y)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x, y):
+        (loss, acc), grads = nn.value_and_grad(model, train_step)(model, x, y)
+        optimizer.update(model, grads)
+        return loss, acc

     losses = []
     accs = []
@@ -72,9 +81,8 @@ def train_epoch(model, train_iter, optimizer, epoch):
         x = mx.array(batch["audio"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()
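In the keyword-spotting pipeline, `.to_stream()` is added before `.prefetch(4, 4)`; as I understand mlx-data, prefetching is a stream operation, so the buffer has to be converted to a stream before batches can be prepared in the background.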
@@ -1,2 +1,2 @@
-mlx
+mlx>=0.2
 mlx-data
@@ -2,6 +2,7 @@

 import math
 import time
+from functools import partial

 import datasets
 import mlx.core as mx
@@ -37,11 +38,6 @@ class TransformerLM(nn.Module):
         x = self.transformer(x, mask)
         return self.out_proj(x)

-    def loss(self, x, y, reduce=True):
-        logits = self(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
-

 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -88,31 +84,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

+    def loss_fn(model, x, y, reduce=True):
+        logits = model(x)
+        losses = nn.losses.cross_entropy(logits, y)
+        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
-    loss_and_grad_fn = nn.value_and_grad(model, model.loss)

-    def eval_fn(model, dataset):
+    def eval_fn(dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = model.loss(bx, by, reduce=False)
+            losses = loss_fn(model, bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inputs, targets):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        optimizer.update(model, grads)
+        return loss
+
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
-        loss, grads = loss_and_grad_fn(inputs, targets)
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        optimizer.update(model, grads)
-        del grads
-        mx.eval(loss, model.parameters())
+        loss = step(inputs, targets)
+        mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
             train_loss = np.mean(losses)
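For the transformer LM, the loss moves from a `TransformerLM.loss` method to a local `loss_fn` defined in `main`, so the compiled `step` and `eval_fn` can share one definition that closes over the model. The compiled step also absorbs the optimizer update, which replaces the old `del grads` / `mx.eval(loss, model.parameters())` bookkeeping with a single `mx.eval(state)` per iteration.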
@@ -1 +1 @@
-mlx >= 0.12
+mlx>=0.2