mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
feat(clip): add linear probe evaluation script
This commit is contained in:
parent
bf21789b17
commit
213a950e6e
55
clip/eval.py
Normal file
55
clip/eval.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Mirror of the Linear Probe Evaluation Script
|
||||
# from the official CLIP Repository.
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from image_processor import CLIPImageProcessor
|
||||
from mlx.data.datasets import load_cifar10
|
||||
from model import CLIPModel
|
||||
from PIL import Image
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root)
|
||||
tr_iter = tr.to_stream().batch(batch_size)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
test_iter = test.to_stream().batch(batch_size)
|
||||
|
||||
return tr_iter, test_iter
|
||||
|
||||
|
||||
def get_features(model, image_proc, iter):
|
||||
all_features = []
|
||||
all_labels = []
|
||||
|
||||
for _, batch in enumerate(iter):
|
||||
image, label = batch["image"], batch["label"]
|
||||
x = image_proc([Image.fromarray(im) for im in image])
|
||||
y = mx.array(label)
|
||||
|
||||
image_embeds = model.get_image_features(x)
|
||||
all_features.append(image_embeds)
|
||||
all_labels.append(y)
|
||||
|
||||
return mx.concatenate(all_features), mx.concatenate(all_labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = CLIPModel.from_pretrained("mlx_model")
|
||||
image_proc = CLIPImageProcessor.from_pretrained("mlx_model")
|
||||
|
||||
train_iter, test_iter = get_cifar10(batch_size=256)
|
||||
train_features, train_labels = get_features(model, image_proc, train_iter)
|
||||
test_features, test_labels = get_features(model, image_proc, test_iter)
|
||||
|
||||
# Perform logistic regression
|
||||
# NOTE: The value of C should be determined via a hyperparameter sweep
|
||||
# using a validation split
|
||||
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
|
||||
classifier.fit(train_features, train_labels)
|
||||
|
||||
# Evaluate using the logistic regression classifier
|
||||
predictions = classifier.predict(test_features)
|
||||
accuracy = np.mean((np.array(test_labels) == predictions).astype(float)) * 100.0
|
||||
print(f"Accuracy = {accuracy:.3f}")
|
@ -1,4 +1,5 @@
|
||||
mlx
|
||||
mlx-data
|
||||
numpy
|
||||
transformers
|
||||
torch
|
||||
|
Loading…
Reference in New Issue
Block a user