# CIFAR and ResNets

An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
configurations in accordance with the original
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load the dataset.
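
For orientation, the CIFAR-10 ResNets in the paper follow a depth rule of 6n + 2 layers: three stages of n residual blocks (two 3x3 convolutions per block), plus an initial convolution and a final linear layer. The sketch below derives n for the depths used in the paper's CIFAR-10 experiments. The helper name `blocks_per_stage` is hypothetical, and which of these depths this example actually exposes is not stated here.

```python
def blocks_per_stage(depth: int) -> int:
    """Return n for a CIFAR-style ResNet whose depth is 6n + 2."""
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError(f"depth must be of the form 6n + 2, got {depth}")
    return (depth - 2) // 6


# Depths from the paper's CIFAR-10 experiments.
PAPER_DEPTHS = (20, 32, 44, 56, 110, 1202)

for depth in PAPER_DEPTHS:
    n = blocks_per_stage(depth)
    print(f"depth {depth}: 3 stages x {n} residual blocks")
```

For example, a 20-layer network corresponds to n = 3, that is, three residual blocks in each of the three stages.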

## Pre-requisites

Install the dependencies:

```
pip install -r requirements.txt
```

## Running the example

Run the example with:

```
python main.py
```

By default, the example runs on the GPU. To run on the CPU, use:

```
python main.py --cpu
```

For all available options, run:

```
python main.py --help
```
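
A boolean flag such as `--cpu` is commonly wired up with `argparse`. The following is a minimal sketch, not the actual contents of `main.py`: the `build_parser` helper and the help text are assumptions, and only the `--cpu` flag itself comes from the commands above. With MLX installed, the device switch could then be done with `mx.set_default_device(mx.cpu)`, which appears only as a comment so that the snippet runs without MLX.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser; the real main.py may define more options.
    parser = argparse.ArgumentParser(description="Train a ResNet on CIFAR-10")
    parser.add_argument(
        "--cpu", action="store_true", help="run on the CPU instead of the GPU"
    )
    return parser


# Simulate `python main.py --cpu` by passing the arguments explicitly.
args = build_parser().parse_args(["--cpu"])

if args.cpu:
    # With MLX available, the default device would be switched here:
    #   import mlx.core as mx
    #   mx.set_default_device(mx.cpu)
    device = "cpu"
else:
    device = "gpu"

print(f"selected device: {device}")
```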

## Results

After training with the default `resnet20` architecture for 100 epochs, you
should see the following results:

```
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
Epoch: 99 | Test acc 0.807
```

Note that this was run on an M1 MacBook Pro with 16 GB of RAM.
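
As a rough sanity check on the throughput figure: the CIFAR-10 training set has 50,000 images, so at the reported rate one epoch takes about two minutes. The back-of-the-envelope estimate below assumes that the throughput logged for the last epoch holds for the whole run and ignores evaluation time, so treat the total as an approximation only.

```python
TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training set
THROUGHPUT = 416.77    # images/sec, taken from the log above
EPOCHS = 100

seconds_per_epoch = TRAIN_IMAGES / THROUGHPUT
total_hours = seconds_per_epoch * EPOCHS / 3600

# prints: 120 s per epoch, 3.3 h in total
print(f"{seconds_per_epoch:.0f} s per epoch, {total_hours:.1f} h in total")
```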

At the time of writing, `mlx` doesn't have built-in learning rate schedules
or a `BatchNorm` layer. We intend to update this example once these features
are added.
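
Until built-in schedules are available, a learning rate can be computed by hand and applied once per epoch. The sketch below shows one possible step-decay rule and is not what this example does: the function name, the milestones, and the decay factor are illustrative assumptions. Applying the value also assumes that the optimizer exposes a writable `learning_rate` attribute, which should be checked against the installed `mlx` version.

```python
def step_decay(base_lr: float, epoch: int, milestones=(50, 75), gamma: float = 0.1) -> float:
    """Multiply base_lr by gamma once for every milestone already reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma**passed


for epoch in (0, 49, 50, 75, 99):
    lr = step_decay(0.1, epoch)
    # In a training loop, this value would be assigned before the epoch starts,
    # e.g. `optimizer.learning_rate = lr` (assumed attribute, not verified here).
    print(f"epoch {epoch:3d}: lr = {lr:.4f}")
```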