mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Fix linear probe script
This commit is contained in:
parent
dee703eacf
commit
6d13a14757
@ -1,5 +1,6 @@
|
||||
# Mirror of the Linear Probe Evaluation Script
|
||||
# from the official CLIP Repository.
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from image_processor import CLIPImageProcessor
|
||||
@ -12,7 +13,6 @@ from tqdm import tqdm
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root).batch(batch_size)
|
||||
|
||||
test = load_cifar10(root=root, train=False).batch(batch_size)
|
||||
|
||||
return tr, test
|
||||
@ -28,6 +28,8 @@ def get_features(model, image_proc, iter):
|
||||
y = mx.array(label)
|
||||
|
||||
image_embeds = model.get_image_features(x)
|
||||
mx.eval(image_embeds)
|
||||
|
||||
all_features.append(image_embeds)
|
||||
all_labels.append(y)
|
||||
|
||||
@ -50,5 +52,5 @@ if __name__ == "__main__":
|
||||
|
||||
# Evaluate using the logistic regression classifier
|
||||
predictions = classifier.predict(test_features)
|
||||
accuracy = np.mean((np.array(test_labels) == predictions).astype(float)) * 100.0
|
||||
accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
|
||||
print(f"Accuracy = {accuracy:.3f}")
|
||||
|
Loading…
Reference in New Issue
Block a user