mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix linear probe script
This commit is contained in:
parent
dee703eacf
commit
6d13a14757
@ -1,5 +1,6 @@
|
|||||||
# Mirror of the Linear Probe Evaluation Script
|
# Mirror of the Linear Probe Evaluation Script
|
||||||
# from the official CLIP Repository.
|
# from the official CLIP Repository.
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from image_processor import CLIPImageProcessor
|
from image_processor import CLIPImageProcessor
|
||||||
@ -12,7 +13,6 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
def get_cifar10(batch_size, root=None):
|
def get_cifar10(batch_size, root=None):
|
||||||
tr = load_cifar10(root=root).batch(batch_size)
|
tr = load_cifar10(root=root).batch(batch_size)
|
||||||
|
|
||||||
test = load_cifar10(root=root, train=False).batch(batch_size)
|
test = load_cifar10(root=root, train=False).batch(batch_size)
|
||||||
|
|
||||||
return tr, test
|
return tr, test
|
||||||
@ -28,6 +28,8 @@ def get_features(model, image_proc, iter):
|
|||||||
y = mx.array(label)
|
y = mx.array(label)
|
||||||
|
|
||||||
image_embeds = model.get_image_features(x)
|
image_embeds = model.get_image_features(x)
|
||||||
|
mx.eval(image_embeds)
|
||||||
|
|
||||||
all_features.append(image_embeds)
|
all_features.append(image_embeds)
|
||||||
all_labels.append(y)
|
all_labels.append(y)
|
||||||
|
|
||||||
@ -50,5 +52,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Evaluate using the logistic regression classifier
|
# Evaluate using the logistic regression classifier
|
||||||
predictions = classifier.predict(test_features)
|
predictions = classifier.predict(test_features)
|
||||||
accuracy = np.mean((np.array(test_labels) == predictions).astype(float)) * 100.0
|
accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
|
||||||
print(f"Accuracy = {accuracy:.3f}")
|
print(f"Accuracy = {accuracy:.3f}")
|
||||||
|
Loading…
Reference in New Issue
Block a user