From 6d13a14757f6a8629c0f9755072d1b5ab090ceb8 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 24 Oct 2024 21:53:00 -0700 Subject: [PATCH] Fix linear probe script --- clip/linear_probe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/clip/linear_probe.py b/clip/linear_probe.py index 27c90be3..2649e397 100644 --- a/clip/linear_probe.py +++ b/clip/linear_probe.py @@ -1,5 +1,6 @@ # Mirror of the Linear Probe Evaluation Script # from the official CLIP Repository. + import mlx.core as mx import numpy as np from image_processor import CLIPImageProcessor @@ -12,7 +13,6 @@ from tqdm import tqdm def get_cifar10(batch_size, root=None): tr = load_cifar10(root=root).batch(batch_size) - test = load_cifar10(root=root, train=False).batch(batch_size) return tr, test @@ -28,6 +28,8 @@ def get_features(model, image_proc, iter): y = mx.array(label) image_embeds = model.get_image_features(x) + mx.eval(image_embeds) + all_features.append(image_embeds) all_labels.append(y) @@ -50,5 +52,5 @@ if __name__ == "__main__": # Evaluate using the logistic regression classifier predictions = classifier.predict(test_features) - accuracy = np.mean((np.array(test_labels) == predictions).astype(float)) * 100.0 + accuracy = (test_labels.squeeze() == predictions).mean().item() * 100 print(f"Accuracy = {accuracy:.3f}")