From 4971462bf0dd7bba07d9f18fb0fd2752a51fde40 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Fri, 25 Oct 2024 05:56:17 +0100 Subject: [PATCH] feat(clip): add linear probe evaluation script (#960) --- clip/linear_probe.py | 56 +++++++++++++++++++++++++++++++++++++++++++ clip/requirements.txt | 1 + 2 files changed, 57 insertions(+) create mode 100644 clip/linear_probe.py diff --git a/clip/linear_probe.py b/clip/linear_probe.py new file mode 100644 index 00000000..2649e397 --- /dev/null +++ b/clip/linear_probe.py @@ -0,0 +1,56 @@ +# Mirror of the Linear Probe Evaluation Script +# from the official CLIP Repository. + +import mlx.core as mx +import numpy as np +from image_processor import CLIPImageProcessor +from mlx.data.datasets import load_cifar10 +from model import CLIPModel +from PIL import Image +from sklearn.linear_model import LogisticRegression +from tqdm import tqdm + + +def get_cifar10(batch_size, root=None): + tr = load_cifar10(root=root).batch(batch_size) + test = load_cifar10(root=root, train=False).batch(batch_size) + + return tr, test + + +def get_features(model, image_proc, iter): + all_features = [] + all_labels = [] + + for batch in tqdm(iter): + image, label = batch["image"], batch["label"] + x = image_proc([Image.fromarray(im) for im in image]) + y = mx.array(label) + + image_embeds = model.get_image_features(x) + mx.eval(image_embeds) + + all_features.append(image_embeds) + all_labels.append(y) + + return mx.concatenate(all_features), mx.concatenate(all_labels) + + +if __name__ == "__main__": + model = CLIPModel.from_pretrained("mlx_model") + image_proc = CLIPImageProcessor.from_pretrained("mlx_model") + + train_iter, test_iter = get_cifar10(batch_size=256) + train_features, train_labels = get_features(model, image_proc, train_iter) + test_features, test_labels = get_features(model, image_proc, test_iter) + + # Perform logistic regression + # NOTE: The value of C should be determined via a hyperparameter sweep + # using a validation split + classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) + classifier.fit(train_features, train_labels) + + # Evaluate using the logistic regression classifier + predictions = classifier.predict(test_features) + accuracy = (test_labels.squeeze() == predictions).mean().item() * 100 + print(f"Accuracy = {accuracy:.3f}") diff --git a/clip/requirements.txt b/clip/requirements.txt index 74f826ea..8e05620e 100644 --- a/clip/requirements.txt +++ b/clip/requirements.txt @@ -1,4 +1,5 @@ mlx +mlx-data numpy transformers torch