From b1b9b11801e4d86f36ac569e199d70b39f00bfe2 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Dec 2023 12:09:10 -0800
Subject: [PATCH] updates + format

---
 cifar/README.md  | 38 ++++++++++++++++++------------------
 cifar/dataset.py | 21 ++++++--------------
 cifar/main.py    | 51 ++++++++++++++++++++++++++++--------------------
 cifar/resnet.py  |  1 -
 4 files changed, 55 insertions(+), 56 deletions(-)

diff --git a/cifar/README.md b/cifar/README.md
index abb2c0f5..118aef9e 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -1,9 +1,13 @@
 # CIFAR and ResNets
 
-An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations in accordance with the original [paper](https://arxiv.org/abs/1512.03385) are available. Also illustrates how to use `mlx-data` to download and load the dataset.
-
+An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
+configurations in accordance with the original
+[paper](https://arxiv.org/abs/1512.03385) are available. The example also
+illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
+load the dataset.
 
 ## Pre-requisites
+
 Install the dependencies:
 
 ```
@@ -11,6 +15,7 @@ pip install -r requirements.txt
 ```
 
 ## Running the example
+
 Run the example with:
 
 ```
@@ -29,23 +34,18 @@ For all available options, run:
 python main.py --help
 ```
 
-
-## Throughput
-
-On the tested device (M1 Macbook Pro, 16GB RAM), I get the following throughput with a `batch_size=256`:
-```
-Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 415.39 images/sec
-```
-
-When training on just the CPU (with the `--cpu` argument), the throughput is significantly lower (almost 30x!):
-```
-Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 13.5 images/sec
-```
-
 ## Results
-After training for 100 epochs, the following results were observed:
+
+After training with the default `resnet20` architecture for 100 epochs, you
+should see the following results:
+
 ```
-Epoch: 99 | avg. tr_loss 0.320 | avg. tr_acc 0.888 | Train Throughput: 416.77 images/sec
-Epoch: 99 | test_acc 0.807
+Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
+Epoch: 99 | Test acc 0.807
 ```
-At the time of writing, `mlx` doesn't have in-built `schedulers`, nor a `BatchNorm` layer. We'll revisit this example for exact reproduction once these features are added.
\ No newline at end of file
+
+Note this was run on an M1 Macbook Pro with 16GB RAM.
+
+At the time of writing, `mlx` doesn't have built-in learning rate schedules,
+nor a `BatchNorm` layer. We intend to update this example once these features
+are added.
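The `dataset.py` hunks below fold the two per-image `key_transform` lambdas into a single `normalize` helper. As a rough, self-contained illustration of what that transform computes, here is the same arithmetic applied to a stand-in image with NumPy; the random `uint8` array and the use of NumPy instead of MLX arrays are assumptions made for this sketch, not part of the patch.

```python
import numpy as np

# Per-channel statistics from dataset.py, broadcast over an H x W x C image.
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))


def normalize(x):
    # Scale uint8 pixels into [0, 1], then standardize each channel.
    x = x.astype("float32") / 255.0
    return (x - mean) / std


# Stand-in for one 32x32 RGB CIFAR-10 image.
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
out = normalize(image)
print(out.shape, out.mean(), out.std())
```

In the patched pipeline this function is applied once per sample via `.key_transform("image", normalize)`, replacing the two chained lambdas.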
diff --git a/cifar/dataset.py b/cifar/dataset.py
index 29f558d1..89b10136 100644
--- a/cifar/dataset.py
+++ b/cifar/dataset.py
@@ -4,13 +4,15 @@ import math
 
 
 def get_cifar10(batch_size, root=None):
     tr = load_cifar10(root=root)
-    num_tr_samples = tr.size()
 
     mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
     std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
 
+    def normalize(x):
+        x = x.astype("float32") / 255.0
+        return (x - mean) / std
+
     tr_iter = (
         tr.shuffle()
         .to_stream()
@@ -18,22 +20,11 @@ def get_cifar10(batch_size, root=None):
         .pad("image", 0, 4, 4, 0.0)
         .pad("image", 1, 4, 4, 0.0)
         .image_random_crop("image", 32, 32)
-        .key_transform("image", lambda x: (x.astype("float32") / 255.0))
-        .key_transform("image", lambda x: (x - mean) / std)
+        .key_transform("image", normalize)
         .batch(batch_size)
     )
 
     test = load_cifar10(root=root, train=False)
-    num_test_samples = test.size()
-
-    test_iter = (
-        test.to_stream()
-        .key_transform("image", lambda x: (x.astype("float32") / 255.0))
-        .key_transform("image", lambda x: (x - mean) / std)
-        .batch(batch_size)
-    )
-
-    num_tr_steps_per_epoch = num_tr_samples // batch_size
-    num_test_steps_per_epoch = num_test_samples // batch_size
+    test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
 
     return tr_iter, test_iter
diff --git a/cifar/main.py b/cifar/main.py
index 29b0cbc7..26d06a6a 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -12,7 +12,8 @@ parser.add_argument(
     "--arch",
     type=str,
     default="resnet20",
-    help="model architecture [resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202]",
+    choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]],
+    help="model architecture",
 )
 parser.add_argument("--batch_size", type=int, default=256, help="batch size")
 parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
@@ -21,10 +22,6 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
 parser.add_argument("--cpu", action="store_true", help="use cpu only")
 
 
-def loss_fn(model, inp, tgt):
-    return mx.mean(nn.losses.cross_entropy(model(inp), tgt))
-
-
 def eval_fn(model, inp, tgt):
     return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
 
@@ -50,17 +47,25 @@ def train_epoch(model, train_iter, optimizer, epoch):
         optimizer.update(model, grads)
         mx.eval(model.parameters(), optimizer.state)
         toc = time.perf_counter()
-        loss_value = loss.item()
-        acc_value = acc.item()
-        losses.append(loss_value)
-        accs.append(acc_value)
-        samples_per_sec.append(x.shape[0] / (toc - tic))
+        loss = loss.item()
+        acc = acc.item()
+        losses.append(loss)
+        accs.append(acc)
+        throughput = x.shape[0] / (toc - tic)
+        samples_per_sec.append(throughput)
         if batch_counter % 10 == 0:
             print(
-                f"Epoch {epoch:02d} [{batch_counter:03d}] | tr_loss {loss_value:.3f} | tr_acc {acc_value:.3f} | Throughput: {x.shape[0] / (toc - tic):.2f} images/second"
+                " | ".join(
+                    (
+                        f"Epoch {epoch:02d} [{batch_counter:03d}]",
+                        f"Train loss {loss:.3f}",
+                        f"Train acc {acc:.3f}",
+                        f"Throughput: {throughput:.2f} images/second",
+                    )
+                )
             )
 
     mean_tr_loss = mx.mean(mx.array(losses))
     mean_tr_acc = mx.mean(mx.array(accs))
     samples_per_sec = mx.mean(mx.array(samples_per_sec))
     return mean_tr_loss, mean_tr_acc, samples_per_sec
@@ -81,24 +86,28 @@ def test_epoch(model, test_iter, epoch):
 
 def main(args):
     mx.random.seed(args.seed)
 
-    model = resnet.__dict__[args.arch]()
+    model = getattr(resnet, args.arch)()
 
-    print("num_params: {:0.04f} M".format(model.num_params() / 1e6))
-    mx.eval(model.parameters())
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) optimizer = optim.Adam(learning_rate=args.lr) train_data, test_data = get_cifar10(args.batch_size) for epoch in range(args.epochs): - epoch_tr_loss, epoch_tr_acc, train_throughput = train_epoch( - model, train_data, optimizer, epoch - ) + tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) print( - f"Epoch: {epoch} | avg. tr_loss {epoch_tr_loss.item():.3f} | avg. tr_acc {epoch_tr_acc.item():.3f} | Train Throughput: {train_throughput.item():.2f} images/sec" + " | ".join( + ( + f"Epoch: {epoch}", + f"avg. Train loss {tr_loss.item():.3f}", + f"avg. Train acc {tr_acc.item():.3f}", + f"Throughput: {throughput.item():.2f} images/sec", + ) + ) ) - epoch_test_acc = test_epoch(model, test_data, epoch) - print(f"Epoch: {epoch} | test_acc {epoch_test_acc.item():.3f}") + test_acc = test_epoch(model, test_data, epoch) + print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") train_data.reset() test_data.reset() diff --git a/cifar/resnet.py b/cifar/resnet.py index 22b8a31a..758ee3de 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -59,7 +59,6 @@ class Block(nn.Module): self.shortcut = None def __call__(self, x): - out = nn.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) if self.shortcut is None: