Fix linear probe script

This commit is contained in:
Angelos Katharopoulos 2024-10-24 21:53:00 -07:00
parent dee703eacf
commit 6d13a14757

View File

@ -1,5 +1,6 @@
# Mirror of the Linear Probe Evaluation Script
# from the official CLIP Repository.
import mlx.core as mx
import numpy as np
from image_processor import CLIPImageProcessor
@ -12,7 +13,6 @@ from tqdm import tqdm
def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root).batch(batch_size)
test = load_cifar10(root=root, train=False).batch(batch_size)
return tr, test
@ -28,6 +28,8 @@ def get_features(model, image_proc, iter):
y = mx.array(label)
image_embeds = model.get_image_features(x)
mx.eval(image_embeds)
all_features.append(image_embeds)
all_labels.append(y)
@ -50,5 +52,5 @@ if __name__ == "__main__":
# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((np.array(test_labels) == predictions).astype(float)) * 100.0
accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
print(f"Accuracy = {accuracy:.3f}")