mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
feat: simplify data handling
This commit is contained in:
parent
213a950e6e
commit
dee703eacf
@ -7,23 +7,22 @@ from mlx.data.datasets import load_cifar10
|
||||
from model import CLIPModel
|
||||
from PIL import Image
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root)
|
||||
tr_iter = tr.to_stream().batch(batch_size)
|
||||
tr = load_cifar10(root=root).batch(batch_size)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
test_iter = test.to_stream().batch(batch_size)
|
||||
test = load_cifar10(root=root, train=False).batch(batch_size)
|
||||
|
||||
return tr_iter, test_iter
|
||||
return tr, test
|
||||
|
||||
|
||||
def get_features(model, image_proc, iter):
|
||||
all_features = []
|
||||
all_labels = []
|
||||
|
||||
for _, batch in enumerate(iter):
|
||||
for batch in tqdm(iter):
|
||||
image, label = batch["image"], batch["label"]
|
||||
x = image_proc([Image.fromarray(im) for im in image])
|
||||
y = mx.array(label)
|
Loading…
Reference in New Issue
Block a user