From 213a950e6eb57b91c0d2c63dc541598f4f54146c Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Wed, 28 Aug 2024 20:54:38 +0100
Subject: [PATCH] feat(clip): add linear probe evaluation script

---
 clip/eval.py          | 68 +++++++++++++++++++++++++++++++++++++++++++
 clip/requirements.txt |  1 +
 2 files changed, 69 insertions(+)
 create mode 100644 clip/eval.py

diff --git a/clip/eval.py b/clip/eval.py
new file mode 100644
index 00000000..5f3dd216
--- /dev/null
+++ b/clip/eval.py
@@ -0,0 +1,68 @@
+# Mirror of the Linear Probe Evaluation Script
+# from the official CLIP Repository.
+import mlx.core as mx
+import numpy as np
+from image_processor import CLIPImageProcessor
+from mlx.data.datasets import load_cifar10
+from model import CLIPModel
+from PIL import Image
+from sklearn.linear_model import LogisticRegression
+
+
+def get_cifar10(batch_size, root=None):
+    """Return batched streams for the CIFAR-10 default and `train=False` splits.
+
+    `root` is forwarded unchanged to `load_cifar10`.
+    """
+    tr = load_cifar10(root=root)
+    tr_iter = tr.to_stream().batch(batch_size)
+
+    test = load_cifar10(root=root, train=False)
+    test_iter = test.to_stream().batch(batch_size)
+
+    return tr_iter, test_iter
+
+
+def get_features(model, image_proc, iter):
+    """Embed every batch of `iter` and return `(features, labels)`.
+
+    Each batch must provide "image" and "label" entries. The per-batch
+    image features and labels are concatenated across all batches.
+    """
+    all_features = []
+    all_labels = []
+
+    for batch in iter:
+        image, label = batch["image"], batch["label"]
+        x = image_proc([Image.fromarray(im) for im in image])
+        y = mx.array(label)
+
+        image_embeds = model.get_image_features(x)
+        # MLX is lazy: evaluate each batch now so the pending compute graph
+        # does not keep growing across the whole dataset.
+        mx.eval(image_embeds)
+
+        all_features.append(image_embeds)
+        all_labels.append(y)
+
+    return mx.concatenate(all_features), mx.concatenate(all_labels)
+
+
+if __name__ == "__main__":
+    model = CLIPModel.from_pretrained("mlx_model")
+    image_proc = CLIPImageProcessor.from_pretrained("mlx_model")
+
+    train_iter, test_iter = get_cifar10(batch_size=256)
+    train_features, train_labels = get_features(model, image_proc, train_iter)
+    test_features, test_labels = get_features(model, image_proc, test_iter)
+
+    # Perform logistic regression
+    # NOTE: The value of C should be determined via a hyperparameter sweep
+    # using a validation split
+    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
+    classifier.fit(train_features, train_labels)
+
+    # Evaluate using the logistic regression classifier
+    predictions = classifier.predict(test_features)
+    accuracy = np.mean((np.array(test_labels) == predictions).astype(float)) * 100.0
+    print(f"Accuracy = {accuracy:.3f}")
diff --git a/clip/requirements.txt b/clip/requirements.txt
index 74f826ea..8e05620e 100644
--- a/clip/requirements.txt
+++ b/clip/requirements.txt
@@ -1,4 +1,5 @@
 mlx
+mlx-data
 numpy
 transformers
 torch