# CIFAR and ResNets
An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
configurations in accordance with the original
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load the dataset.
## Pre-requisites
Install the dependencies:
```
pip install -r requirements.txt
```
## Running the example
Run the example with:
```
python main.py
```
By default the example runs on the GPU. To run on the CPU, use:
```
python main.py --cpu
```
For all available options, run:
```
python main.py --help
```
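
The flags that appear in this README (`--cpu`, plus `--arch`, `--epochs`, and `--batch` in the distributed example) could be declared with `argparse` roughly as follows. This is an illustrative sketch, not the contents of `main.py`: the `build_parser` helper, the help strings, and the batch-size default are assumptions. Only the `resnet20` and 30-epoch defaults are stated in this README.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser covering the flags mentioned in this README."""
    parser = argparse.ArgumentParser(description="Train a ResNet on CIFAR-10.")
    parser.add_argument("--arch", type=str, default="resnet20",
                        help="ResNet configuration to train")
    parser.add_argument("--epochs", type=int, default=30,
                        help="number of training epochs")
    # The default batch size below is an assumption, not taken from main.py.
    parser.add_argument("--batch", type=int, default=256,
                        help="batch size")
    parser.add_argument("--cpu", action="store_true",
                        help="run on the CPU instead of the GPU")
    return parser


# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args(["--cpu", "--epochs", "5"])
print(args.arch, args.epochs, args.cpu)
```

With `argparse`, the `--help` output is generated automatically from the declared arguments.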
## Results
After training the default `resnet20` architecture for 30 epochs, you
should see results similar to the following:
```
Epoch: 29 | avg. Train loss 0.294 | avg. Train acc 0.897 | Throughput: 270.81 images/sec
Epoch: 29 | Test acc 0.841
```
Note that this was run on an M1 MacBook Pro with 16 GB of RAM.
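
As a rough sanity check, the logged throughput can be converted into an estimated training time. The sketch below assumes the standard CIFAR-10 training split of 50,000 images and that the logged throughput holds for every epoch. It ignores evaluation and any other overhead, so treat the result as a lower bound. The function name is illustrative.

```python
CIFAR10_TRAIN_IMAGES = 50_000  # size of the standard CIFAR-10 training split


def estimate_training_seconds(throughput: float, epochs: int,
                              images_per_epoch: int = CIFAR10_TRAIN_IMAGES) -> float:
    """Estimate pure training time from a throughput given in images/sec."""
    if throughput <= 0:
        raise ValueError("throughput must be positive")
    return epochs * images_per_epoch / throughput


# Throughput taken from the log line above (270.81 images/sec).
seconds = estimate_training_seconds(throughput=270.81, epochs=30)
print(f"~{seconds / 60:.0f} minutes for 30 epochs")  # prints "~92 minutes for 30 epochs"
```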
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.
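
In the meantime, a schedule can be written as a plain function of the epoch and applied by hand. The sketch below is a generic step decay, not the schedule used for the results above. The function name, the milestone epochs, and the decay factor are assumptions chosen for illustration. How the returned value is assigned to the optimizer depends on the optimizer API and is deliberately left out.

```python
def step_decay_lr(epoch: int, base_lr: float,
                  milestones: tuple = (15, 25), gamma: float = 0.1) -> float:
    """Return base_lr scaled by gamma once for each milestone already reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


# Show the learning rate at the start and at each (hypothetical) milestone.
for epoch in (0, 15, 25):
    print(epoch, step_decay_lr(epoch, base_lr=1e-3))
```

Calling such a function once per epoch, before the training loop for that epoch, keeps the schedule independent of the rest of the training code.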
## Distributed training
The example also supports distributed data-parallel training. You can launch a
distributed training run as follows:
```shell
$ cat >hostfile.json
[
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
```
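
The hostfile is plain JSON, so it can also be generated programmatically instead of typed by hand. A minimal sketch, assuming the same two-key entry layout shown above. The `write_hostfile` helper is hypothetical and the host names and IPs are the same placeholders used in the example.

```python
import json
from pathlib import Path


def write_hostfile(hosts, path="hostfile.json"):
    """Write (ssh_host, [ip, ...]) pairs in the hostfile layout shown above."""
    entries = [{"ssh": ssh_host, "ips": list(ips)} for ssh_host, ips in hosts]
    Path(path).write_text(json.dumps(entries, indent=2) + "\n")
    return entries


# Placeholder values; replace them with real host names and IP addresses.
write_hostfile([
    ("host-to-ssh-to", ["ip-to-bind-to"]),
    ("host-to-ssh-to", ["ip-to-bind-to"]),
])
```

The resulting `hostfile.json` can then be passed to `mlx.launch` with `--hostfile`, as in the command above.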