simplified ResNet, expanded README with throughput and performance

This commit is contained in:
Sarthak Yadav
2023-12-14 09:05:04 +01:00
parent 2439333a57
commit 15a6c155a8
5 changed files with 56 additions and 36 deletions

View File

@@ -1,11 +1,10 @@
# CIFAR and ResNets
* This example shows how to run ResNets on CIFAR10 dataset, in accordance with the original [paper](https://arxiv.org/abs/1512.03385).
* Also illustrates how to use `mlx-data` to download and load the dataset.
An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations in accordance with the original [paper](https://arxiv.org/abs/1512.03385) are available. Also illustrates how to use `mlx-data` to download and load the dataset.
## Pre-requisites
* Install the dependencies:
Install the dependencies:
```
pip install -r requirements.txt
@@ -21,7 +20,7 @@ python main.py
By default the example runs on the GPU. To run on the CPU, use:
```
python main.py --cpu_only
python main.py --cpu
```
For all available options, run:
@@ -29,3 +28,24 @@ For all available options, run:
```
python main.py --help
```
## Throughput
On the tested device (M1 Macbook Pro, 16GB RAM), I get the following throughput with a `batch_size=256`:
```
Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 415.39 images/sec
```
When training on just the CPU (with the `--cpu` argument), the throughput is significantly lower (almost 30x!):
```
Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 13.5 images/sec
```
## Results
After training for 100 epochs, the following results were observed:
```
Epoch: 99 | avg. tr_loss 0.320 | avg. tr_acc 0.888 | Train Throughput: 416.77 images/sec
Epoch: 99 | test_acc 0.807
```
At the time of writing, `mlx` doesn't have in-built `schedulers`, nor a `BatchNorm` layer. We'll revisit this example for exact reproduction once these features are added.