mlx-examples/cifar/README.md
Markus Enzweiler 2b61d9deb6
Updated CIFAR-10 ResNet example to use BatchNorm instead of LayerNorm (#257)
* replaced nn.LayerNorm by nn.BatchNorm

* mlx>=0.0.8 required

* updated default to 30 epochs instead of 100

* updated README after adding BatchNorm

* requires mlx>=0.0.9

* updated README.md with results for mlx-0.0.9
2024-01-12 05:43:11 -08:00

CIFAR and ResNets

An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations that follow the original paper are available. The example also shows how to use MLX Data to load the dataset.
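
For CIFAR, the original ResNet paper builds networks of depth 6n + 2: three stages of n residual blocks each (two convolutions per block), plus an initial convolution and a final linear layer. The sketch below only illustrates how a depth such as 20 maps to blocks per stage. The function name `blocks_per_stage` is hypothetical and not part of this example's code, and the depths in the loop come from the paper, not from a list of what main.py exposes.

```python
def blocks_per_stage(depth: int) -> int:
    """Return n for a CIFAR-style ResNet of depth 6n + 2."""
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError(f"depth must be of the form 6n + 2, got {depth}")
    return (depth - 2) // 6


# Depths reported in the original paper for CIFAR-10 (illustrative only)
for depth in (20, 32, 44, 56, 110):
    print(f"resnet{depth}: 3 stages x {blocks_per_stage(depth)} blocks")
```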

Pre-requisites

Install the dependencies:

pip install -r requirements.txt

Running the example

Run the example with:

python main.py

By default the example runs on the GPU. To run on the CPU, use:

python main.py --cpu

For all available options, run:

python main.py --help
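
As a rough illustration of how a flag such as `--cpu` could be wired up, here is a minimal argparse sketch. Only `--cpu` and `--help` are confirmed by the text above. The parser description and the `select_device` helper are assumptions made for this sketch, and the real main.py may be organized differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser; argparse adds --help automatically
    parser = argparse.ArgumentParser(description="CLI sketch (not the real main.py)")
    # store_true: the flag defaults to False, so the GPU is used unless --cpu is given
    parser.add_argument("--cpu", action="store_true",
                        help="run on the CPU instead of the GPU")
    return parser


def select_device(argv: list[str]) -> str:
    """Return the name of the device implied by the command-line arguments."""
    args = build_parser().parse_args(argv)
    return "cpu" if args.cpu else "gpu"


print(select_device([]))         # no flag given
print(select_device(["--cpu"]))  # CPU requested explicitly
```

In an MLX script, the chosen device would then typically be applied once at startup, for instance through `mx.set_default_device`, before any arrays are created.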

Results

After training with the default resnet20 architecture for 30 epochs, you should see the following results:

Epoch: 29 | avg. Train loss 0.294 | avg. Train acc 0.897 | Throughput: 270.81 images/sec
Epoch: 29 | Test acc 0.841

Note that these results were obtained on an M1 MacBook Pro with 16 GB of RAM.
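
To put the throughput figure in perspective, here is a back-of-the-envelope estimate. CIFAR-10 has 50,000 training images, so at about 270.81 images/sec one training epoch takes roughly three minutes and 30 epochs take roughly an hour and a half. The sketch assumes the throughput reported for the last epoch holds for every epoch and ignores the time spent on evaluation.

```python
TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training set
THROUGHPUT = 270.81    # images/sec, taken from the training log line
EPOCHS = 30            # default number of epochs

seconds_per_epoch = TRAIN_IMAGES / THROUGHPUT
total_minutes = EPOCHS * seconds_per_epoch / 60

print(f"{seconds_per_epoch:.0f} s per epoch, {total_minutes:.0f} min in total")
```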

At the time of writing, MLX does not have built-in learning rate schedules. We intend to update this example once they are added.
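
Without built-in schedules, one common workaround is to compute the learning rate by hand at the start of each epoch. The following is a minimal sketch of a step decay and not what this example does: the function name, base rate, decay factor, and step size are all hypothetical values chosen for illustration.

```python
def step_decay(base_lr: float, epoch: int, factor: float = 0.1, step: int = 10) -> float:
    """Scale the base learning rate by `factor` once every `step` epochs."""
    return base_lr * factor ** (epoch // step)


for epoch in (0, 9, 10, 20, 29):
    lr = step_decay(1e-3, epoch)
    # In a training loop, this value would be handed to the optimizer
    # before the epoch starts; how to do so depends on the optimizer API.
    print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```

With these illustrative settings, the rate stays at its base value for epochs 0 to 9 and is reduced tenfold at epochs 10 and 20.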