mlx-examples/cifar
Awni Hannun 27c0a8c002
Add llms subdir + update README (#145)
* add llms subdir + update README

* nits

* use same pre-commit as mlx

* update readmes a bit

* format
2023-12-20 10:22:25 -08:00
dataset.py Add llms subdir + update README (#145) 2023-12-20 10:22:25 -08:00
main.py Add llms subdir + update README (#145) 2023-12-20 10:22:25 -08:00
README.md typo / nits 2023-12-14 12:14:01 -08:00
requirements.txt simplified ResNet, expanded README with throughput and performance 2023-12-14 09:05:04 +01:00
resnet.py Add llms subdir + update README (#145) 2023-12-20 10:22:25 -08:00

CIFAR and ResNets

An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations from the original ResNet paper are available. The example also shows how to use MLX Data to load the dataset.
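For reference, the CIFAR-10 ResNets in the original paper have depth 6n + 2. That is a 3x3 stem convolution, three stages of n basic blocks (two 3x3 convolutions each) with 16, 32 and 64 filters, and a final fully connected layer. The plain-Python sketch below maps a depth such as 20 to its per-stage block count. It is not part of this example's code. The depths listed come from the paper, and which of them `resnet.py` actually exposes is an assumption, so check `python main.py --help`.

```python
# Sketch of CIFAR-style ResNet depth arithmetic (depth = 6n + 2), as described
# in the original ResNet paper. Not taken from resnet.py.

STAGE_WIDTHS = (16, 32, 64)  # output channels of the three stages


def blocks_per_stage(depth: int) -> int:
    """Return n such that depth == 6n + 2, or raise if no such n >= 1 exists."""
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError(f"depth must be 6n + 2 with n >= 1, got {depth}")
    return (depth - 2) // 6


def describe(depth: int) -> list[tuple[int, int]]:
    """List (channels, num_blocks) for each of the three stages."""
    n = blocks_per_stage(depth)
    return [(width, n) for width in STAGE_WIDTHS]


def layer_count(depth: int) -> int:
    """Count weighted layers: stem conv + 2 convs per block + final linear."""
    n = blocks_per_stage(depth)
    return 1 + 2 * n * len(STAGE_WIDTHS) + 1


if __name__ == "__main__":
    for depth in (20, 32, 44, 56, 110):
        print(f"resnet{depth}: n={blocks_per_stage(depth)}, stages={describe(depth)}")
```

Counting layers this way confirms the naming: resnet20 has n = 3, so 1 + 2·3·3 + 1 = 20 weighted layers.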

Prerequisites

Install the dependencies:

pip install -r requirements.txt

Running the example

Run the example with:

python main.py

By default the example runs on the GPU. To run on the CPU, use:

python main.py --cpu

For all available options, run:

python main.py --help

Results

After training with the default resnet20 architecture for 100 epochs, you should see results similar to the following:

Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
Epoch: 99 | Test acc 0.807

Note that this was run on an M1 MacBook Pro with 16GB of RAM.
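The throughput figure gives a rough sense of wall-clock training time. The plain-Python sketch below parses a training line in the format shown above and divides the 50,000 CIFAR-10 training images by the reported throughput. It assumes throughput is measured over training images only and ignores test-set evaluation, so treat the result as an estimate. `parse_train_line` and `seconds_per_epoch` are hypothetical helpers, not part of `main.py`, and the log format may differ between versions.

```python
import re

# Matches the training log line shown above (assumed format; may change).
TRAIN_LINE = re.compile(
    r"Epoch: (?P<epoch>\d+) \| avg\. Train loss (?P<loss>[\d.]+) \| "
    r"avg\. Train acc (?P<acc>[\d.]+) \| Throughput: (?P<tput>[\d.]+) images/sec"
)

CIFAR10_TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training split


def parse_train_line(line: str) -> dict:
    """Extract epoch, loss, accuracy and throughput from one log line."""
    m = TRAIN_LINE.search(line)
    if m is None:
        raise ValueError(f"unrecognized log line: {line!r}")
    return {
        "epoch": int(m["epoch"]),
        "loss": float(m["loss"]),
        "acc": float(m["acc"]),
        "throughput": float(m["tput"]),
    }


def seconds_per_epoch(throughput: float, num_images: int = CIFAR10_TRAIN_IMAGES) -> float:
    """Estimated time for one pass over the training set."""
    return num_images / throughput


line = "Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec"
stats = parse_train_line(line)
secs = seconds_per_epoch(stats["throughput"])
print(f"~{secs:.0f} s per epoch, ~{100 * secs / 3600:.1f} h for 100 epochs")
```

For the M1 run above this works out to roughly two minutes per epoch, or a little over three hours for 100 epochs.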

At the time of writing, MLX doesn't have built-in learning rate schedules or a BatchNorm layer. We intend to update this example once these features are added.
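In the meantime, a schedule can be computed by hand. The sketch below implements the step decay used for CIFAR-10 in the original ResNet paper: start at 0.1 and divide by 10 at 32k and 48k iterations. The function name and defaults are illustrative assumptions, not what `main.py` does. Assuming the optimizer exposes a settable `learning_rate`, the returned value could be assigned to it before each update.

```python
# Hypothetical step-decay schedule, following the CIFAR-10 recipe of the
# original ResNet paper. Not part of main.py.


def step_decay_lr(
    step: int,
    base_lr: float = 0.1,
    milestones: tuple[int, ...] = (32_000, 48_000),
    factor: float = 0.1,
) -> float:
    """Multiply base_lr by `factor` once for every milestone already reached."""
    num_decays = sum(1 for m in milestones if step >= m)
    return base_lr * factor**num_decays


if __name__ == "__main__":
    for step in (0, 31_999, 32_000, 48_000, 63_999):
        print(step, step_decay_lr(step))
```

Milestones are counted in optimizer steps, not epochs. With a batch size of 128, one pass over the 50,000 training images is about 391 steps.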