mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-16 07:21:12 +08:00
parent
848f118ac5
commit
0371d90ccb
@ -14,10 +14,16 @@ Run the example with:
|
|||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
By default the example runs on the CPU. To run on the GPU, use:
|
By default, the example runs on the CPU. To run on the GPU, use:
|
||||||
|
|
||||||
```
|
```
|
||||||
python main.py --gpu
|
python main.py --gpu
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For a full list of options run:
|
||||||
|
|
||||||
|
```
|
||||||
|
python main.py --help
|
||||||
|
```
|
||||||
|
|
||||||
To run the PyTorch or Jax examples install the respective framework.
|
To run the PyTorch or Jax examples install the respective framework.
|
||||||
|
@ -54,10 +54,10 @@ if __name__ == "__main__":
|
|||||||
batch_size = 256
|
batch_size = 256
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
learning_rate = 1e-1
|
learning_rate = 1e-1
|
||||||
|
dataset = "mnist"
|
||||||
|
|
||||||
# Load the data
|
# Load the data
|
||||||
train_images, train_labels, test_images, test_labels = mnist.mnist()
|
train_images, train_labels, test_images, test_labels = getattr(mnist, dataset)()
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
key, subkey = jax.random.split(jax.random.PRNGKey(seed))
|
key, subkey = jax.random.split(jax.random.PRNGKey(seed))
|
||||||
params = init_model(
|
params = init_model(
|
||||||
|
@ -45,7 +45,7 @@ def batch_iterate(batch_size, X, y):
|
|||||||
yield X[ids], y[ids]
|
yield X[ids], y[ids]
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(args):
|
||||||
seed = 0
|
seed = 0
|
||||||
num_layers = 2
|
num_layers = 2
|
||||||
hidden_dim = 32
|
hidden_dim = 32
|
||||||
@ -57,7 +57,9 @@ def main():
|
|||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
# Load the data
|
# Load the data
|
||||||
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())
|
train_images, train_labels, test_images, test_labels = map(
|
||||||
|
mx.array, getattr(mnist, args.dataset)()
|
||||||
|
)
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||||
@ -83,7 +85,14 @@ def main():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
|
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
|
||||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="mnist",
|
||||||
|
choices=["mnist", "fashion_mnist"],
|
||||||
|
help="The dataset to use.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not args.gpu:
|
if not args.gpu:
|
||||||
mx.set_default_device(mx.cpu)
|
mx.set_default_device(mx.cpu)
|
||||||
main()
|
main(args)
|
||||||
|
@ -8,7 +8,9 @@ from urllib import request
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def mnist(save_dir="/tmp"):
|
def mnist(
|
||||||
|
save_dir="/tmp", base_url="http://yann.lecun.com/exdb/mnist/", filename="mnist.pkl"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load the MNIST dataset in 4 tensors: train images, train labels,
|
Load the MNIST dataset in 4 tensors: train images, train labels,
|
||||||
test images, and test labels.
|
test images, and test labels.
|
||||||
@ -20,7 +22,6 @@ def mnist(save_dir="/tmp"):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def download_and_save(save_file):
|
def download_and_save(save_file):
|
||||||
base_url = "http://yann.lecun.com/exdb/mnist/"
|
|
||||||
filename = [
|
filename = [
|
||||||
["training_images", "train-images-idx3-ubyte.gz"],
|
["training_images", "train-images-idx3-ubyte.gz"],
|
||||||
["test_images", "t10k-images-idx3-ubyte.gz"],
|
["test_images", "t10k-images-idx3-ubyte.gz"],
|
||||||
@ -45,13 +46,15 @@ def mnist(save_dir="/tmp"):
|
|||||||
with open(save_file, "wb") as f:
|
with open(save_file, "wb") as f:
|
||||||
pickle.dump(mnist, f)
|
pickle.dump(mnist, f)
|
||||||
|
|
||||||
save_file = os.path.join(save_dir, "mnist.pkl")
|
save_file = os.path.join(save_dir, filename)
|
||||||
if not os.path.exists(save_file):
|
if not os.path.exists(save_file):
|
||||||
download_and_save(save_file)
|
download_and_save(save_file)
|
||||||
with open(save_file, "rb") as f:
|
with open(save_file, "rb") as f:
|
||||||
mnist = pickle.load(f)
|
mnist = pickle.load(f)
|
||||||
|
|
||||||
preproc = lambda x: x.astype(np.float32) / 255.0
|
def preproc(x):
|
||||||
|
return x.astype(np.float32) / 255.0
|
||||||
|
|
||||||
mnist["training_images"] = preproc(mnist["training_images"])
|
mnist["training_images"] = preproc(mnist["training_images"])
|
||||||
mnist["test_images"] = preproc(mnist["test_images"])
|
mnist["test_images"] = preproc(mnist["test_images"])
|
||||||
return (
|
return (
|
||||||
@ -62,6 +65,14 @@ def mnist(save_dir="/tmp"):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fashion_mnist(save_dir="/tmp"):
|
||||||
|
return mnist(
|
||||||
|
save_dir,
|
||||||
|
base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
|
||||||
|
filename="fashion_mnist.pkl",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_x, train_y, test_x, test_y = mnist()
|
train_x, train_y, test_x, test_y = mnist()
|
||||||
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
|
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
|
||||||
|
@ -49,6 +49,13 @@ def batch_iterate(batch_size, X, y, device):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
|
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
|
||||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="mnist",
|
||||||
|
choices=["mnist", "fashion_mnist"],
|
||||||
|
help="The dataset to use.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.gpu:
|
if not args.gpu:
|
||||||
@ -71,7 +78,9 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
return torch.from_numpy(x.astype(int)).to(device)
|
return torch.from_numpy(x.astype(int)).to(device)
|
||||||
|
|
||||||
train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist())
|
train_images, train_labels, test_images, test_labels = map(
|
||||||
|
to_tensor, getattr(mnist, args.dataset)()
|
||||||
|
)
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
|
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user