feat: simplify data handling

This commit is contained in:
Saurav Maheshkar 2024-09-14 23:31:17 +01:00
parent 213a950e6e
commit dee703eacf
No known key found for this signature in database
GPG Key ID: 9947E4A9BBB93993

View File

@@ -7,23 +7,22 @@ from mlx.data.datasets import load_cifar10
from model import CLIPModel from model import CLIPModel
from PIL import Image from PIL import Image
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
def get_cifar10(batch_size, root=None): def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root) tr = load_cifar10(root=root).batch(batch_size)
tr_iter = tr.to_stream().batch(batch_size)
test = load_cifar10(root=root, train=False) test = load_cifar10(root=root, train=False).batch(batch_size)
test_iter = test.to_stream().batch(batch_size)
return tr_iter, test_iter return tr, test
def get_features(model, image_proc, iter): def get_features(model, image_proc, iter):
all_features = [] all_features = []
all_labels = [] all_labels = []
for _, batch in enumerate(iter): for batch in tqdm(iter):
image, label = batch["image"], batch["label"] image, label = batch["image"], batch["label"]
x = image_proc([Image.fromarray(im) for im in image]) x = image_proc([Image.fromarray(im) for im in image])
y = mx.array(label) y = mx.array(label)